Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-22 08:07:56 +01:00)
Implement text streaming (#10)
Still experimental. There might be bugs.
parent ca13acdfa0
commit 0f01a3b1fa

server.py: 119 changed lines
@@ -139,25 +139,28 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings

-    input_ids = encode(question, tokens)
+    for i in range(tokens):
+        input_ids = encode(question, 1)
+        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')

-    cuda = ".cuda()" if args.cpu else ""
-    if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}){cuda}")
-    else:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+        cuda = ".cuda()" if args.cpu else ""
+        if eos_token is None:
+            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+        else:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")

-    reply = tokenizer.decode(output[0], skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
-    else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        question = reply
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            yield reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)

 # Choosing the default model
 if args.model is not None:
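In the hunk above, the single model.generate() call is replaced by a loop that generates one token per iteration (max_new_tokens=1), feeds the growing text back in as the next prompt via question = reply, and yields the partial reply each time, turning generate_reply into a generator. A minimal standalone sketch of the same idea using Hugging Face transformers directly; the "gpt2" model name and the prompt are placeholders, not part of this commit:

# Illustrative sketch: naive one-token-at-a-time streaming, as in this commit.
# Each step re-encodes the whole text and calls generate() for one new token.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt, max_new_tokens=50):
    text = prompt
    for _ in range(max_new_tokens):
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        output = model.generate(input_ids, max_new_tokens=1, do_sample=True)
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        yield text   # each partial reply can be pushed straight to the UI

for partial in stream_reply("Hello, my name is"):
    print(partial)

Re-running the full forward pass for every single token is simple but slow, which fits the "still experimental" caveat in the commit message.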
@@ -205,20 +208,20 @@ if args.notebook:
         with gr.Column():
             preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')

-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 elif args.chat or args.cai_chat:
     history = []

     # This gets the new line characters right.
-    def chat_response_cleaner(text):
+    def clean_chat_message(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
         return text

-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        text = chat_response_cleaner(text)
+    def generate_chat_prompt(text, tokens, name1, name2, context):
+        text = clean_chat_message(text)

         rows = [f"{context}\n\n"]
         i = len(history)-1
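show_progress is switched to False on these handlers because Gradio's progress overlay would otherwise sit on top of the output component while partial text streams in; the streaming itself comes from binding a generator function, whose every yield updates the outputs. A rough, self-contained sketch of that wiring (component names here are illustrative, not the webui's):

import time
import gradio as gr

def fake_stream(prompt):
    # Generator event handler: each yield replaces the output textbox contents.
    reply = ""
    for word in prompt.split():
        reply += word + " "
        time.sleep(0.2)
        yield reply

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Generate")
    # show_progress=False keeps the progress overlay from hiding the streamed text.
    btn.click(fake_stream, inp, out, show_progress=False)

demo.queue()   # generator handlers require the queue, hence interface.queue() below
demo.launch()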
@@ -234,26 +237,42 @@ elif args.chat or args.cai_chat:
             rows.pop(1)

         question = ''.join(rows)
+        return question

-        if check:
-            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
-        else:
-            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):]
-            idx = reply.find(f"\n{name1}:")
-            if idx != -1:
-                reply = reply[:idx]
-            reply = chat_response_cleaner(reply)
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        history.append(['', ''])
+        question = generate_chat_prompt(text, tokens, name1, name2, context)
+        eos_token = '\n' if check else None
+        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+            reply = i[0]

-        history.append((text, reply))
-        return history
+            if check:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+            else:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):]
+                idx = reply.find(f"\n{name1}:")
+                if idx != -1:
+                    reply = reply[:idx]
+                reply = clean_chat_message(reply)
+
+            history[-1] = [text, reply]
+
+            # Prevent the chat log from flashing if something like "\nYo" is generated just
+            # before "\nYou:" is completed
+            tmp = f"\n{name1}:"
+            found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    found = True
+
+            if not found:
+                yield history

     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
-        return generate_chat_html(history, name1, name2)
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+            yield generate_chat_html(history, name1, name2)

     def remove_last_message(name1, name2):
         history.pop()
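The found flag above suppresses any intermediate update whose tail could be the start of the "\n{name1}:" marker, so the chat log does not flash a fragment like "\nYo" just before "\nYou:" is stripped. The same test in isolation, as a small illustrative helper (not part of the commit):

def ends_with_partial_marker(reply, name1="You"):
    # True if reply currently ends with a proper prefix of "\n{name1}:".
    marker = f"\n{name1}:"
    return any(reply.endswith(marker[:j]) for j in range(1, len(marker)))

# Updates are held back while this is True:
assert ends_with_partial_marker("Sure, I can help.\nYo")
assert not ends_with_partial_marker("Sure, I can help.")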
@@ -305,13 +324,13 @@ elif args.chat or args.cai_chat:
             check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')

         if args.cai_chat:
-            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(clear_html, [], display1, show_progress=False)
+            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(clear_html, [], display1, show_progress=True)
         else:
-            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(lambda x: "", display1, display1)
+            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(lambda x: "", display1, display1, show_progress=True)

         btn2.click(clear)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
@@ -320,8 +339,9 @@ elif args.chat or args.cai_chat:
 else:

     def continue_wrapper(question, tokens, inference_settings, selected_model):
-        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
-        return a, a, b, c
+        for i in generate_reply(question, tokens, inference_settings, selected_model):
+            a, b, c = i
+            yield a, a, b, c

     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)
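Because generate_reply is now a generator, continue_wrapper can no longer unpack a single return value; it iterates the stream and re-yields each partial result, duplicating the text so it reaches both textbox outputs. The general adapter pattern, as a tiny illustrative helper (not part of the commit):

def reshape_stream(gen):
    # Consume (text, markdown, html) tuples and re-yield them with the text
    # duplicated, matching continue_wrapper's four output components.
    for text, markdown, html in gen:
        yield text, text, markdown, html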
@@ -341,10 +361,11 @@ else:
         with gr.Tab('HTML'):
             html = gr.HTML()

-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)

+interface.queue()
 if args.no_listen:
     interface.launch(share=False)
 else: