From de72e83508c1e88c67af2231e1242220cc61e96b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 29 Jan 2023 14:27:22 -0300 Subject: [PATCH] Reorganize things --- server.py | 58 +++++++++++++++++++++++++++---------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/server.py b/server.py index 0d19b740..b17a2f1a 100644 --- a/server.py +++ b/server.py @@ -46,7 +46,7 @@ args = parser.parse_args() if (args.chat or args.cai_chat) and not args.no_stream: print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n") - + settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, @@ -326,6 +326,7 @@ default_text = settings['prompt_gpt4chan'] if model_name.lower().startswith(('gp description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}" buttons = {} +gen_events = [] if args.chat or args.cai_chat: history = {'internal': [], 'visible': []} @@ -600,9 +601,9 @@ if args.chat or args.cai_chat: suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else '' with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface: if args.cai_chat: - display1 = gr.HTML(value=generate_chat_html([], "", "", character)) + display = gr.HTML(value=generate_chat_html([], "", "", character)) else: - display1 = gr.Chatbot() + display = gr.Chatbot() textbox = gr.Textbox(label='Input') buttons["Generate"] = gr.Button("Generate") with gr.Row(): @@ -664,35 +665,34 @@ if args.chat or args.cai_chat: input_params = [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check, history_size_slider] if args.cai_chat: - gen_event = buttons["Generate"].click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen") - gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream) + gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream)) else: - gen_event = buttons["Generate"].click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen") - gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream) - gen_event3 = buttons["Regenerate"].click(regenerate_wrapper, input_params, display1, show_progress=args.no_stream) + gen_events.append(buttons["Generate"].click(chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(chatbot_wrapper, input_params, display, show_progress=args.no_stream)) + gen_events.append(buttons["Regenerate"].click(regenerate_wrapper, input_params, display, show_progress=args.no_stream)) buttons["Send last reply to input"].click(send_last_reply_to_input, [], textbox, show_progress=args.no_stream) - buttons["Replace last reply"].click(replace_last_reply, [textbox, name1, name2], display1, show_progress=args.no_stream) - - buttons["Clear"].click(clear_chat_log, [character_menu, name1, name2], display1) - buttons["Remove last"].click(remove_last_message, [name1, name2], [display1, textbox], show_progress=False) + buttons["Replace last reply"].click(replace_last_reply, [textbox, name1, name2], display, show_progress=args.no_stream) + buttons["Clear"].click(clear_chat_log, [character_menu, name1, name2], display) + buttons["Remove last"].click(remove_last_message, [name1, name2], [display, textbox], show_progress=False) + buttons["Stop"].click(None, None, None, cancels=gen_events) + buttons["Download"].click(save_history, inputs=[], outputs=[download]) + buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu]) for i in ["Generate", "Regenerate", "Replace last reply"]: buttons[i].click(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False) - buttons["Stop"].click(None, None, None, cancels=[gen_event, gen_event2, gen_event3]) - buttons["Download"].click(save_history, inputs=[], outputs=[download]) - character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1]) - upload.upload(upload_history, [upload, name1, name2], []) - buttons["Upload character"].click(upload_character, [upload_char, upload_img], [character_menu]) + character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display]) upload_img_tavern.upload(upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu]) + upload.upload(upload_history, [upload, name1, name2], []) upload_img_me.upload(upload_your_profile_picture, [upload_img_me], []) if args.cai_chat: - upload.upload(redraw_html, [name1, name2], [display1]) - upload_img_me.upload(redraw_html, [name1, name2], [display1]) + upload.upload(redraw_html, [name1, name2], [display]) + upload_img_me.upload(redraw_html, [name1, name2], [display]) else: - upload.upload(lambda : history['visible'], [], [display1]) - upload_img_me.upload(lambda : history['visible'], [], [display1]) + upload.upload(lambda : history['visible'], [], [display]) + upload_img_me.upload(lambda : history['visible'], [], [display]) elif args.notebook: with gr.Blocks(css=css, analytics_enabled=False) as interface: @@ -720,9 +720,9 @@ elif args.notebook: if args.extensions is not None: create_extensions_block() - gen_event = buttons["Generate"].click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") - gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream) - buttons["Stop"].click(None, None, None, cancels=[gen_event, gen_event2]) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)) + buttons["Stop"].click(None, None, None, cancels=gen_events) else: with gr.Blocks(css=css, analytics_enabled=False) as interface: @@ -740,7 +740,7 @@ else: buttons["Generate"] = gr.Button("Generate") with gr.Row(): with gr.Column(): - cont = gr.Button("Continue") + buttons["Continue"] = gr.Button("Continue") with gr.Column(): buttons["Stop"] = gr.Button("Stop") if args.extensions is not None: @@ -754,10 +754,10 @@ else: with gr.Tab('HTML'): html = gr.HTML() - gen_event = buttons["Generate"].click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen") - gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream) - cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream) - buttons["Stop"].click(None, None, None, cancels=[gen_event, gen_event2, cont_event]) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=args.no_stream)) + buttons["Stop"].click(None, None, None, cancels=gen_events) interface.queue() if args.listen: