Pygmalion: add checkbox for choosing whether to stop at newline or not

This commit is contained in:
oobabooga 2023-01-13 15:02:17 -03:00
parent 3a00cb1bbd
commit ecb2cc2194

View File

@ -188,7 +188,7 @@ if args.notebook:
elif args.chat: elif args.chat:
history = [] history = []
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context): def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
question = context+'\n\n' question = context+'\n\n'
for i in range(len(history)): for i in range(len(history)):
question += f"{name1}: {history[i][0][3:-5].strip()}\n" question += f"{name1}: {history[i][0][3:-5].strip()}\n"
@ -196,8 +196,16 @@ elif args.chat:
question += f"{name1}: {text.strip()}\n" question += f"{name1}: {text.strip()}\n"
question += f"{name2}:" question += f"{name2}:"
if check:
reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0] reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
reply = reply[len(question):].split('\n')[0].strip() reply = reply[len(question):].split('\n')[0].strip()
else:
reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
reply = reply[len(question):].strip()
idx = reply.find(f"\n{name1}:")
if idx != -1:
reply = reply[:idx]
history.append((text, reply)) history.append((text, reply))
return history return history
@ -228,14 +236,17 @@ elif args.chat:
name1 = gr.Textbox(value=name1_str, lines=1, label='Your name') name1 = gr.Textbox(value=name1_str, lines=1, label='Your name')
name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name') name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name')
context = gr.Textbox(value=context_str, lines=2, label='Context') context = gr.Textbox(value=context_str, lines=2, label='Context')
with gr.Row():
check = gr.Checkbox(value=True, label='Stop generating at new line character?')
with gr.Column(): with gr.Column():
display1 = gr.Chatbot() display1 = gr.Chatbot()
textbox = gr.Textbox(lines=2, label='Input') textbox = gr.Textbox(lines=2, label='Input')
btn = gr.Button("Generate") btn = gr.Button("Generate")
btn2 = gr.Button("Clear history") btn2 = gr.Button("Clear history")
btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True, api_name="textgen") btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen")
textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True) textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True)
btn2.click(clear) btn2.click(clear)
btn.click(lambda x: "", textbox, textbox, show_progress=False) btn.click(lambda x: "", textbox, textbox, show_progress=False)
textbox.submit(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False)