Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-23 00:18:20 +01:00)
Implement text streaming (#10)
Still experimental. There might be bugs.
This commit is contained in:
parent ca13acdfa0
commit 0f01a3b1fa

server.py (119 changed lines)

@@ -139,25 +139,28 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings
 
-    input_ids = encode(question, tokens)
+    for i in range(tokens):
+        input_ids = encode(question, 1)
+        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
 
-    cuda = ".cuda()" if args.cpu else ""
-    if eos_token is None:
-        output = eval(f"model.generate(input_ids, {preset}){cuda}")
-    else:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
-        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
+        cuda = ".cuda()" if args.cpu else ""
+        if eos_token is None:
+            output = eval(f"model.generate(input_ids, {preset}){cuda}")
+        else:
+            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
+            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
 
-    reply = tokenizer.decode(output[0], skip_special_tokens=True)
-    reply = reply.replace(r'<|endoftext|>', '')
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
-    else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        reply = reply.replace(r'<|endoftext|>', '')
+        question = reply
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            yield reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 # Choosing the default model
 if args.model is not None:
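
The core of the change above: instead of one model.generate call that returns only the finished reply, generate_reply now produces a single token per iteration and yields the partial text so the UI can update as text arrives. Below is a minimal, self-contained sketch of that pattern; the model name ("gpt2") and the step count are illustrative placeholders, not values from the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

def stream_reply(prompt, max_steps=20, model_name="gpt2"):
    # Generate one token per step and yield the growing text, mirroring the
    # max_new_tokens=1 loop introduced in generate_reply above.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_steps):
        output = model.generate(input_ids, max_new_tokens=1, do_sample=True)
        input_ids = output  # feed the full sequence back in, like question = reply
        yield tokenizer.decode(output[0], skip_special_tokens=True)

if __name__ == "__main__":
    for partial in stream_reply("The quick brown fox"):
        print(partial)
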
@@ -205,20 +208,20 @@ if args.notebook:
         with gr.Column():
             preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
 
-    btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen")
-    textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True)
+    btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen")
+    textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 elif args.chat or args.cai_chat:
     history = []
 
     # This gets the new line characters right.
-    def chat_response_cleaner(text):
+    def clean_chat_message(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
         return text
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        text = chat_response_cleaner(text)
+    def generate_chat_prompt(text, tokens, name1, name2, context):
+        text = clean_chat_message(text)
 
         rows = [f"{context}\n\n"]
         i = len(history)-1
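
For context on the show_progress flips here and in the hunks below: when a click/submit handler is a Python generator, Gradio sends each yielded value to the output components as it arrives, and switching show_progress off presumably keeps the progress overlay from covering the partial text. A minimal sketch of that behaviour, assuming the Gradio 3.x-era API this repo used at the time; component names are illustrative and none of this code is from the commit.

import time
import gradio as gr

def slow_echo(text):
    # A generator handler: each yield replaces the contents of the output box.
    out = ""
    for ch in text:
        time.sleep(0.05)
        out += ch
        yield out

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Streamed output")
    btn = gr.Button("Generate")
    btn.click(slow_echo, inp, out, show_progress=False)

demo.queue()   # generator handlers require the queue, cf. interface.queue() below
demo.launch()
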
@@ -234,26 +237,42 @@ elif args.chat or args.cai_chat:
             rows.pop(1)
 
         question = ''.join(rows)
+        return question
 
-        if check:
-            reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
-        else:
-            reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            idx = reply.rfind(question[-1024:])
-            reply = reply[idx+min(1024, len(question)):]
-            idx = reply.find(f"\n{name1}:")
-            if idx != -1:
-                reply = reply[:idx]
-            reply = chat_response_cleaner(reply)
-
-        history.append((text, reply))
-        return history
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        history.append(['', ''])
+        question = generate_chat_prompt(text, tokens, name1, name2, context)
+        eos_token = '\n' if check else None
+        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
+            reply = i[0]
+
+            if check:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+            else:
+                idx = reply.rfind(question[-1024:])
+                reply = reply[idx+min(1024, len(question)):]
+                idx = reply.find(f"\n{name1}:")
+                if idx != -1:
+                    reply = reply[:idx]
+                reply = clean_chat_message(reply)
+
+            history[-1] = [text, reply]
+
+            # Prevent the chat log from flashing if something like "\nYo" is generated just
+            # before "\nYou:" is completed
+            tmp = f"\n{name1}:"
+            found = False
+            for j in range(1, len(tmp)):
+                if reply[-j:] == tmp[:j]:
+                    found = True
+
+            if not found:
+                yield history
 
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check)
-        return generate_chat_html(history, name1, name2)
+        for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+            yield generate_chat_html(history, name1, name2)
 
     def remove_last_message(name1, name2):
         history.pop()
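
The anti-flashing guard added to chatbot_wrapper can be read in isolation: before yielding, it checks whether the tail of the partial reply could be the start of the stop string "\n{name1}:"; if so, the update is held back so a half-typed "\nYou:" never flashes in the chat log. A small standalone sketch of the same check (names and sample strings are illustrative):

def tail_matches_prefix(reply, stop):
    # True if reply currently ends with a non-empty, proper prefix of `stop`,
    # i.e. the stop string might still be in the middle of being generated.
    return any(reply.endswith(stop[:j]) for j in range(1, len(stop)))

name1 = "You"
stop = f"\n{name1}:"
print(tail_matches_prefix("Sure, one moment.\nYo", stop))  # True  -> hold this update back
print(tail_matches_prefix("Sure, one moment.", stop))      # False -> safe to yield
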
@@ -305,13 +324,13 @@ elif args.chat or args.cai_chat:
         check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
 
         if args.cai_chat:
-            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(clear_html, [], display1, show_progress=False)
+            btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(clear_html, [], display1, show_progress=True)
         else:
-            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
-            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
-            btn2.click(lambda x: "", display1, display1)
+            btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen")
+            textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False)
+            btn2.click(lambda x: "", display1, display1, show_progress=True)
 
         btn2.click(clear)
         btn3.click(remove_last_message, [name1, name2], display1, show_progress=False)
@@ -320,8 +339,9 @@ elif args.chat or args.cai_chat:
 else:
 
     def continue_wrapper(question, tokens, inference_settings, selected_model):
-        a, b, c = generate_reply(question, tokens, inference_settings, selected_model)
-        return a, a, b, c
+        for i in generate_reply(question, tokens, inference_settings, selected_model):
+            a, b, c = i
+            yield a, a, b, c
 
     with gr.Blocks(css=css, analytics_enabled=False) as interface:
         gr.Markdown(description)
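
continue_wrapper now just iterates generate_reply and re-yields each partial result with the text duplicated into both textboxes. The same delegation pattern, reduced to a runnable toy in which the inner generator is a stand-in for generate_reply:

def fake_generate_reply():
    # Stand-in for generate_reply: yields (text, markdown, html) triples.
    for step in range(3):
        yield f"text {step}", f"markdown {step}", f"html {step}"

def continue_wrapper_like():
    for a, b, c in fake_generate_reply():
        # Duplicate the text so the editable box and the output box stay in sync.
        yield a, a, b, c

for item in continue_wrapper_like():
    print(item)
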
@@ -341,10 +361,11 @@ else:
         with gr.Tab('HTML'):
             html = gr.HTML()
 
-        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True, api_name="textgen")
-        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=True)
-        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
+        btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen")
+        cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False)
+        textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False)
 
+interface.queue()
 if args.no_listen:
     interface.launch(share=False)
 else: