diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index fddde30c..91d40e4c 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -6,6 +6,7 @@ import torch from transformers import BlipForConditionalGeneration, BlipProcessor from modules import chat, shared +from modules.ui import gather_interface_values # If 'state' is True, will hijack the next chat generation with # custom input text given by 'value' in the format [text, visible_text] @@ -42,7 +43,7 @@ def ui(): picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) # Call the generation function - picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) - - # Clear the picture from the upload field - picture_select.upload(lambda: None, [], [picture_select], show_progress=False) + picture_select.upload( + gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + lambda: None, None, picture_select, show_progress=False) diff --git a/modules/ui.py b/modules/ui.py index def1faaf..cb494bb7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,6 +2,8 @@ from pathlib import Path import gradio as gr +from modules import shared + refresh_symbol = '\U0001f504' # 🔄 with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: @@ -14,6 +16,21 @@ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: chat_js = f.read() +def list_interface_input_elements(chat=False): + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings'] + if chat: + elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode'] + return elements + + +def gather_interface_values(*args): + output = {} + for i, element in enumerate(shared.input_elements): + output[element] = args[i] + output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]") + return output + + class ToolButton(gr.Button, gr.components.FormComponent): """Small button with single emoji as text, fits inside gradio forms""" diff --git a/server.py b/server.py index 54dac6e0..f89cbae0 100644 --- a/server.py +++ b/server.py @@ -24,7 +24,6 @@ from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt, unload_model from modules.text_generation import generate_reply, stop_everything_event - # Loading custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -361,21 +360,6 @@ else: title = 'Text generation web UI' -def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings'] - if chat: - elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode'] - return elements - - -def gather_interface_values(*args): - output = {} - for i, element in enumerate(shared.input_elements): - output[element] = args[i] - output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]") - return output - - def create_interface(): gen_events = [] if shared.args.extensions is not None and len(shared.args.extensions) > 0: @@ -384,7 +368,7 @@ def create_interface(): with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: if shared.is_chat(): - shared.input_elements = list_interface_input_elements(chat=True) + shared.input_elements = ui.list_interface_input_elements(chat=True) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['Chat input'] = gr.State() @@ -469,33 +453,33 @@ def create_interface(): reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']] gen_events.append(shared.gradio['Generate'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['textbox'].submit( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Regenerate'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Continue'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Impersonate'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream) ) @@ -551,7 +535,7 @@ def create_interface(): shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) elif shared.args.notebook: - shared.input_elements = list_interface_input_elements(chat=False) + shared.input_elements = ui.list_interface_input_elements(chat=False) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) with gr.Tab("Text generation", elem_id="main"): with gr.Row(): @@ -584,13 +568,13 @@ def create_interface(): output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['textbox'].submit( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) @@ -599,7 +583,7 @@ def create_interface(): shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") else: - shared.input_elements = list_interface_input_elements(chat=False) + shared.input_elements = ui.list_interface_input_elements(chat=False) shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) with gr.Tab("Text generation", elem_id="main"): with gr.Row(): @@ -630,19 +614,19 @@ def create_interface(): output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['textbox'].submit( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['Continue'].click( - gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( + ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then( #None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") )