Implement text streaming (#10)

Still experimental. There might be bugs.
Author: oobabooga, 2023-01-18 19:06:50 -03:00
Parent: ca13acdfa0
Commit: 0f01a3b1fa

server.py (119 changed lines)

@@ -139,25 +139,28 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
             preset = infile.read()
         loaded_preset = inference_settings
 
-    input_ids = encode(question, tokens)
-    cuda = ".cuda()" if args.cpu else ""
-    if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}){cuda}")
-    else:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
-
-    reply = tokenizer.decode(output[0], skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
-    else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+    for i in range(tokens):
+        input_ids = encode(question, 1)
+        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
+        cuda = ".cuda()" if args.cpu else ""
+        if eos_token is None:
+            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+        else:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+
+        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        question = reply
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            yield reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 # Choosing the default model
 if args.model is not None:
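
In short, generate_reply is now a generator: instead of producing the whole completion with a single model.generate call, it generates one token per loop iteration, feeds the partial reply back in as the new prompt, and yields the intermediate text each time. Below is a minimal, self-contained sketch of that pattern using Hugging Face transformers directly; the model name and the stream_reply helper are placeholders for illustration, not what the web UI actually loads.

```python
# Minimal sketch of the same token-by-token streaming idea.
# "gpt2" and stream_reply are illustrative, not part of server.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

def stream_reply(prompt, max_new_tokens=200, model_name="gpt2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    text = prompt
    for _ in range(max_new_tokens):
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        # max_new_tokens=1 mirrors the preset rewrite in the hunk above
        output = model.generate(input_ids, max_new_tokens=1, do_sample=True)
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        yield text  # partial reply, one token longer each iteration

# Usage:
# for partial in stream_reply("Hello, my name is"):
#     print(partial)
```

Note that re-encoding the growing prompt on every step, as both the sketch and the loop above do, trades efficiency for simplicity: the model reprocesses the whole context each iteration instead of reusing a cache.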

@@ -205,20 +208,20 @@ if args.notebook:
             with gr.Column():
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 elif args.chat or args.cai_chat:
     history = []
 
     # This gets the new line characters right.
-    def chat_response_cleaner(text):
+    def clean_chat_message(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
         return text
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        text = chat_response_cleaner(text)
+    def generate_chat_prompt(text, tokens, name1, name2, context):
+        text = clean_chat_message(text)
         rows = [f"{context}\n\n"]
         i = len(history)-1
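
The prompt-building half of the old chatbot_wrapper is split out here into generate_chat_prompt, which assembles the persona context plus as much chat history as fits, then trims the oldest lines when the prompt would overflow the context window (the rows.pop(1) loop at the start of the next hunk). A rough sketch of that pattern follows; the function name, count_tokens, and max_tokens are illustrative stand-ins for the tokenizer-based length check in server.py, not its actual helpers.

```python
# Illustrative sketch only; not the real generate_chat_prompt.
def build_chat_prompt(text, history, name1, name2, context, max_tokens, count_tokens):
    rows = [f"{context}\n\n"]
    # Walk the history newest-to-oldest, inserting just after the context,
    # so the finished prompt reads in chronological order.
    for user_msg, bot_msg in reversed(history):
        rows.insert(1, f"{name2}: {bot_msg.strip()}\n")
        rows.insert(1, f"{name1}: {user_msg.strip()}\n")
    rows.append(f"{name1}: {text}\n")
    rows.append(f"{name2}:")
    # Drop the oldest history lines until the prompt fits the context window.
    while len(rows) > 3 and count_tokens(''.join(rows)) > max_tokens:
        rows.pop(1)
    return ''.join(rows)
```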

@@ -234,26 +237,42 @@ elif args.chat or args.cai_chat:
             rows.pop(1)
         question = ''.join(rows)
+        return question
 
-        if check:
-            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
-        else:
-            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):]
-            idx = reply.find(f"\n{name1}:")
-            if idx != -1:
-                reply = reply[:idx]
-            reply = chat_response_cleaner(reply)
-        history.append((text, reply))
-        return history
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        history.append(['', ''])
+        question = generate_chat_prompt(text, tokens, name1, name2, context)
+        eos_token = '\n' if check else None
+        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+            reply = i[0]
+
+            if check:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+            else:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):]
+                idx = reply.find(f"\n{name1}:")
+                if idx != -1:
+                    reply = reply[:idx]
+                reply = clean_chat_message(reply)
+
+            history[-1] = [text, reply]
+
+            # Prevent the chat log from flashing if something like "\nYo" is generated just
+            # before "\nYou:" is completed
+            tmp = f"\n{name1}:"
+            found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    found = True
+            if not found:
+                yield history
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
-        return generate_chat_html(history, name1, name2)
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+            yield generate_chat_html(history, name1, name2)
 
     def remove_last_message(name1, name2):
         history.pop()
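
The last block of the new chatbot_wrapper is what keeps the streamed chat log from flickering: a partial reply whose tail is the beginning of the next-speaker marker (for example "\nYo" on its way to "\nYou:") is simply not yielded, so half-typed stop strings never show up on screen. The same check, pulled out as a small stand-alone helper (the function name is illustrative, not from server.py):

```python
# Returns True if the partial reply ends with a proper prefix of the
# next-speaker marker, e.g. "\nY" or "\nYo" for the marker "\nYou:".
def ends_with_partial_marker(reply, name1):
    marker = f"\n{name1}:"
    return any(reply.endswith(marker[:j]) for j in range(1, len(marker)))

print(ends_with_partial_marker("Sure thing!\nYo", "You"))  # True  -> hold this update back
print(ends_with_partial_marker("Sure thing!", "You"))      # False -> safe to yield
```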

@@ -305,13 +324,13 @@ elif args.chat or args.cai_chat:
         check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
 
         if args.cai_chat:
-            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(clear_html, [], display1, show_progress=False)
+            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(clear_html, [], display1, show_progress=True)
         else:
-            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(lambda x: "", display1, display1)
+            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(lambda x: "", display1, display1, show_progress=True)
         btn2.click(clear)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)

@@ -320,8 +339,9 @@ elif args.chat or args.cai_chat:
 else:
     def continue_wrapper(question, tokens, inference_settings, selected_model):
-        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
-        return a, a, b, c
+        for i in generate_reply(question, tokens, inference_settings, selected_model):
+            a, b, c = i
+            yield a, a, b, c
 
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)

@@ -341,10 +361,11 @@ else:
         with gr.Tab('HTML'):
             html = gr.HTML()
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
 
+interface.queue()
 if args.no_listen:
     interface.launch(share=False)
 else:
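
Two things make the streaming visible in the browser: the event handlers now point at generator functions and pass show_progress=False so the output components are not greyed out between updates, and the new interface.queue() call enables Gradio's queue, which is what lets each yielded value be pushed to the client as it is produced. A minimal sketch of that pattern, independent of server.py (the dummy generator and component names are illustrative):

```python
# Minimal sketch of streaming a generator's output through a Gradio Blocks app.
import time
import gradio as gr

def fake_stream(prompt):
    text = prompt
    for word in ["one", "two", "three"]:
        time.sleep(0.5)
        text += " " + word
        yield text  # each yield replaces the contents of the output textbox

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Generate")
    btn.click(fake_stream, inp, out, show_progress=False)

demo.queue()   # the queue is what allows generator callbacks to stream
demo.launch()
```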