diff --git a/server.py b/server.py index b9599462..1951f92b 100644 --- a/server.py +++ b/server.py @@ -188,7 +188,7 @@ if args.notebook: elif args.chat: history = [] - def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context): + def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): question = context+'\n\n' for i in range(len(history)): question += f"{name1}: {history[i][0][3:-5].strip()}\n" @@ -196,8 +196,16 @@ elif args.chat: question += f"{name1}: {text.strip()}\n" question += f"{name2}:" - reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0] - reply = reply[len(question):].split('\n')[0].strip() + if check: + reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0] + reply = reply[len(question):].split('\n')[0].strip() + else: + reply = generate_reply(question, tokens, inference_settings, selected_model)[0] + reply = reply[len(question):].strip() + idx = reply.find(f"\n{name1}:") + if idx != -1: + reply = reply[:idx] + history.append((text, reply)) return history @@ -228,14 +236,17 @@ elif args.chat: name1 = gr.Textbox(value=name1_str, lines=1, label='Your name') name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name') context = gr.Textbox(value=context_str, lines=2, label='Context') + with gr.Row(): + check = gr.Checkbox(value=True, label='Stop generating at new line character?') + with gr.Column(): display1 = gr.Chatbot() textbox = gr.Textbox(lines=2, label='Input') btn = gr.Button("Generate") btn2 = gr.Button("Clear history") - btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True, api_name="textgen") - textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context], display1, show_progress=True) + btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen") + textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True) btn2.click(clear) btn.click(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False)