diff --git a/server.py b/server.py index 40e2239e..0cc35846 100644 --- a/server.py +++ b/server.py @@ -191,27 +191,7 @@ else: description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}" -if args.notebook: - with gr.Blocks(css=css, analytics_enabled=False) as interface: - gr.Markdown(description) - with gr.Tab('Raw'): - textbox = gr.Textbox(value=default_text, lines=23) - with gr.Tab('Markdown'): - markdown = gr.Markdown() - with gr.Tab('HTML'): - html = gr.HTML() - btn = gr.Button("Generate") - - length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) - with gr.Row(): - with gr.Column(): - model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') - with gr.Column(): - preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') - - btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen") - textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False) -elif args.chat or args.cai_chat: +if args.chat or args.cai_chat: history = [] # This gets the new line characters right. @@ -311,10 +291,9 @@ elif args.chat or args.cai_chat: textbox = gr.Textbox(lines=2, label='Input') btn = gr.Button("Generate") with gr.Row(): - with gr.Column(): - btn3 = gr.Button("Remove last message") - with gr.Column(): - btn2 = gr.Button("Clear history") + btn2 = gr.Button("Clear history") + stop = gr.Button("Stop") + btn3 = gr.Button("Remove last message") length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) with gr.Row(): @@ -330,25 +309,44 @@ elif args.chat or args.cai_chat: check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?') if args.cai_chat: - btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") - textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) - btn2.click(clear_html, [], display1, show_progress=True) + gen_event = btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") + gen_event2 = textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) + btn2.click(clear_html, [], display1, show_progress=False) else: - btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") - textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) - btn2.click(lambda x: "", display1, display1, show_progress=True) + gen_event = btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False, api_name="textgen") + gen_event2 = textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=False) + btn2.click(lambda x: "", display1, display1, show_progress=False) btn2.click(clear) btn3.click(remove_last_message, [name1, name2], display1, show_progress=False) btn.click(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False) + stop.click(None, None, None, cancels=[gen_event, gen_event2]) + +elif args.notebook: + with gr.Blocks(css=css, analytics_enabled=False) as interface: + gr.Markdown(description) + with gr.Tab('Raw'): + textbox = gr.Textbox(value=default_text, lines=23) + with gr.Tab('Markdown'): + markdown = gr.Markdown() + with gr.Tab('HTML'): + html = gr.HTML() + btn = gr.Button("Generate") + stop = gr.Button("Stop") + + length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) + with gr.Row(): + with gr.Column(): + model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') + with gr.Column(): + preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') + + gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False, api_name="textgen") + gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False) + stop.click(None, None, None, cancels=[gen_event, gen_event2]) + else: - - def continue_wrapper(question, tokens, inference_settings, selected_model): - for i in generate_reply(question, tokens, inference_settings, selected_model): - a, b, c = i - yield a, a, b, c - with gr.Blocks(css=css, analytics_enabled=False) as interface: gr.Markdown(description) with gr.Row(): @@ -358,7 +356,11 @@ else: preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset') model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model') btn = gr.Button("Generate") - cont = gr.Button("Continue") + with gr.Row(): + with gr.Column(): + cont = gr.Button("Continue") + with gr.Column(): + stop = gr.Button("Stop") with gr.Column(): with gr.Tab('Raw'): output_textbox = gr.Textbox(lines=15, label='Output') @@ -367,9 +369,10 @@ else: with gr.Tab('HTML'): html = gr.HTML() - btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen") - cont.click(continue_wrapper, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, textbox, markdown, html], show_progress=False) - textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) + gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False, api_name="textgen") + gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) + cont_event = cont.click(generate_reply, [output_textbox, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=False) + stop.click(None, None, None, cancels=[gen_event, gen_event2, cont_event]) interface.queue() if args.no_listen: