Add Interface mode tab

This commit is contained in:
oobabooga 2023-03-15 23:29:56 -03:00
parent b50172255a
commit 4d64a57092
4 changed files with 213 additions and 175 deletions

View File

@ -33,6 +33,6 @@ svg {
ol li p, ul li p { ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #settings, #chat-settings { #main, #parameters, #chat-settings, #interface-mode {
border: 0; border: 0;
} }

View File

@ -11,9 +11,12 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
if name in available_extensions: if name in available_extensions:
print(f'Loading the extension "{name}"... ', end='') print(f'Loading the extension "{name}"... ', end='')
try:
exec(f"import extensions.{name}.script") exec(f"import extensions.{name}.script")
state[name] = [True, i] state[name] = [True, i]
print('Ok.') print('Ok.')
except:
print('Fail.')
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
@ -42,6 +45,7 @@ def create_extensions_block():
extension.params[param] = shared.settings[_id] extension.params[param] = shared.settings[_id]
# Creating the extension ui elements # Creating the extension ui elements
if len(state) > 0:
with gr.Box(elem_id="extensions"): with gr.Box(elem_id="extensions"):
gr.Markdown("Extensions") gr.Markdown("Extensions")
for extension, name in iterator(): for extension, name in iterator():

View File

@ -19,6 +19,9 @@ gradio = {}
# Generation input parameters # Generation input parameters
input_params = [] input_params = []
# For restarting the interface
need_restart = False
settings = { settings = {
'max_new_tokens': 200, 'max_new_tokens': 200,
'max_new_tokens_min': 1, 'max_new_tokens_min': 1,

View File

@ -176,8 +176,6 @@ else:
shared.args.extensions = shared.args.extensions or [] shared.args.extensions = shared.args.extensions or []
if extension not in shared.args.extensions: if extension not in shared.args.extensions:
shared.args.extensions.append(extension) shared.args.extensions.append(extension)
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
# Default model # Default model
if shared.args.model is not None: if shared.args.model is not None:
@ -199,13 +197,18 @@ else:
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
# Default UI settings # Default UI settings
gen_events = []
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
title ='Text generation web UI' title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
def create_interface():
gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.args.chat or shared.args.cai_chat: if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@ -263,7 +266,7 @@ with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) e
with gr.Tab('Upload TavernAI Character Card'): with gr.Tab('Upload TavernAI Character Card'):
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Settings", elem_id="settings"): with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box(): with gr.Box():
gr.Markdown("Chat parameters") gr.Markdown("Chat parameters")
with gr.Row(): with gr.Row():
@ -337,7 +340,7 @@ with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) e
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_model_and_preset_menus() create_model_and_preset_menus()
with gr.Tab("Settings", elem_id="settings"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
@ -369,7 +372,7 @@ with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) e
shared.gradio['markdown'] = gr.Markdown() shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'): with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML() shared.gradio['html'] = gr.HTML()
with gr.Tab("Settings", elem_id="settings"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
@ -380,15 +383,43 @@ with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) e
shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Interface mode", elem_id="interface-mode"):
def set_interface_mode(mode, choices):
shared.args.extensions = choices
for k in ["notebook", "chat", "cai_chat"]:
exec(f"shared.args.{k} = False")
if mode != "default":
exec(f"shared.args.{mode} = True")
shared.need_restart = True
extensions = get_available_extensions()
modes = ["default", "notebook", "chat", "cai_chat"]
current_mode = "default"
for mode in modes:
if hasattr(shared.args, mode) and eval(f"shared.args.{mode}"):
current_mode = mode
modes_menu = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
group = gr.CheckboxGroup(choices=extensions, value=shared.args.extensions, label="Available extensions")
kill = gr.Button("Apply and restart the interface")
kill.click(set_interface_mode, [modes_menu, group], None)
kill.click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2000)}')
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
# Launch the interface
shared.gradio['interface'].queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
# I think that I will need this later create_interface()
while True: while True:
time.sleep(0.5) time.sleep(0.5)
if shared.need_restart:
shared.need_restart = False
shared.gradio['interface'].close()
create_interface()