From 14139317055f2dd947e54472cc81d69956f25fe3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 15 Mar 2023 12:01:32 -0300 Subject: [PATCH] Add a header bar and redesign the interface (#293) --- extensions/gallery/script.py | 2 +- modules/ui.py | 9 ++ server.py | 159 ++++++++++++++++++++--------------- 3 files changed, 99 insertions(+), 71 deletions(-) diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 8a2d7cf9..fbf23bc9 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -76,7 +76,7 @@ def generate_html(): return container_html def ui(): - with gr.Accordion("Character gallery"): + with gr.Accordion("Character gallery", open=False): update = gr.Button("Refresh") gallery = gr.HTML(value=generate_html()) update.click(generate_html, [], gallery) diff --git a/modules/ui.py b/modules/ui.py index bb193e35..27233153 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -38,6 +38,9 @@ svg { ol li p, ul li p { display: inline-block; } +#main, #settings, #extensions, #chat-settings { + border: 0; +} """ chat_css = """ @@ -64,6 +67,12 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* { } """ +page_js = """ +document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px" +document.getElementById("main").parentNode.style = "padding: 0; margin: 0" +document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0" +""" + class ToolButton(gr.Button, gr.components.FormComponent): """Small button with single emoji as text, fits inside gradio forms""" diff --git a/server.py b/server.py index 4ac81f01..a7ec4888 100644 --- a/server.py +++ b/server.py @@ -101,9 +101,7 @@ def upload_soft_prompt(file): return name -def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) - +def create_model_and_preset_menus(): with gr.Row(): with gr.Column(): with gr.Row(): @@ -114,7 +112,11 @@ def create_settings_menus(default_preset): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'): +def create_settings_menus(default_preset): + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) + + with gr.Box(): + gr.Markdown('Custom generation parameters') with gr.Row(): with gr.Column(): shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') @@ -128,9 +130,11 @@ def create_settings_menus(default_preset): shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') + with gr.Box(): gr.Markdown('Contrastive search:') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') + with gr.Box(): gr.Markdown('Beam search (uses a lot of VRAM):') with gr.Row(): with gr.Column(): @@ -139,7 +143,8 @@ def create_settings_menus(default_preset): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Accordion('Soft prompt', open=False, elem_id='accordion'): + with gr.Box(): + gr.Markdown('Soft prompt') with gr.Row(): shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button') @@ -202,26 +207,41 @@ suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' if shared.args.chat or shared.args.cai_chat: with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: - if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) - else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) - shared.gradio['textbox'] = gr.Textbox(label='Input') - with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['Generate'] = gr.Button('Generate') - with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') - shared.gradio['Regenerate'] = gr.Button('Regenerate') - with gr.Row(): - shared.gradio['Copy last reply'] = gr.Button('Copy last reply') - shared.gradio['Replace last reply'] = gr.Button('Replace last reply') - shared.gradio['Remove last'] = gr.Button('Remove last') + with gr.Tab("Text generation", elem_id="main"): + if shared.args.cai_chat: + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) + else: + shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) + shared.gradio['textbox'] = gr.Textbox(label='Input') + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + shared.gradio['Impersonate'] = gr.Button('Impersonate') + shared.gradio['Regenerate'] = gr.Button('Regenerate') + with gr.Row(): + shared.gradio['Copy last reply'] = gr.Button('Copy last reply') + shared.gradio['Replace last reply'] = gr.Button('Replace last reply') + shared.gradio['Remove last'] = gr.Button('Remove last') - shared.gradio['Clear history'] = gr.Button('Clear history') - shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) - shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - with gr.Tab('Chat settings'): + shared.gradio['Clear history'] = gr.Button('Clear history') + shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) + shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) + + create_model_and_preset_menus() + + with gr.Box(): + with gr.Row(): + with gr.Column(): + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) + with gr.Column(): + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + + if shared.args.extensions is not None: + extensions_module.create_extensions_block() + + with gr.Tab("Chat settings", elem_id="chat-settings"): shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') @@ -255,21 +275,11 @@ if shared.args.chat or shared.args.cai_chat: with gr.Tab('Upload TavernAI Character Card'): shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - with gr.Tab('Generation settings'): - with gr.Row(): - with gr.Column(): - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - with gr.Column(): - shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + with gr.Tab("Settings", elem_id="settings"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] - if shared.args.extensions is not None: - with gr.Tab('Extensions'): - extensions_module.create_extensions_block() - function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) @@ -310,58 +320,66 @@ if shared.args.chat or shared.args.cai_chat: shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}") shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Tab('Raw'): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23) - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() + with gr.Tab("Text generation", elem_id="main"): + with gr.Tab('Raw'): + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25) + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() - shared.gradio['Generate'] = gr.Button('Generate') - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - create_settings_menus(default_preset) - if shared.args.extensions is not None: - extensions_module.create_extensions_block() + create_model_and_preset_menus() + if shared.args.extensions is not None: + extensions_module.create_extensions_block() + + with gr.Tab("Settings", elem_id="settings"): + create_settings_menus(default_preset) shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}") else: with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Row(): - with gr.Column(): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - shared.gradio['Generate'] = gr.Button('Generate') - with gr.Row(): - with gr.Column(): - shared.gradio['Continue'] = gr.Button('Continue') - with gr.Column(): - shared.gradio['Stop'] = gr.Button('Stop') + with gr.Tab("Text generation", elem_id="main"): + with gr.Row(): + with gr.Column(): + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + with gr.Column(): + shared.gradio['Continue'] = gr.Button('Continue') + with gr.Column(): + shared.gradio['Stop'] = gr.Button('Stop') - create_settings_menus(default_preset) - if shared.args.extensions is not None: - extensions_module.create_extensions_block() + create_model_and_preset_menus() + if shared.args.extensions is not None: + extensions_module.create_extensions_block() - with gr.Column(): - with gr.Tab('Raw'): - shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output') - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() + with gr.Column(): + with gr.Tab('Raw'): + shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output') + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() + with gr.Tab("Settings", elem_id="settings"): + create_settings_menus(default_preset) shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] @@ -369,6 +387,7 @@ else: gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}") shared.gradio['interface'].queue() if shared.args.listen: