mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Make interface state (mostly) persistent on page reload
This commit is contained in:
parent
47809e28aa
commit
b1ee674d75
@ -20,6 +20,9 @@ processing_message = '*Is typing...*'
|
|||||||
# UI elements (buttons, sliders, HTML, etc)
|
# UI elements (buttons, sliders, HTML, etc)
|
||||||
gradio = {}
|
gradio = {}
|
||||||
|
|
||||||
|
# For keeping the values of UI elements on page reload
|
||||||
|
persistent_interface_state = {}
|
||||||
|
|
||||||
# Generation input parameters
|
# Generation input parameters
|
||||||
input_params = []
|
input_params = []
|
||||||
|
|
||||||
|
@ -33,9 +33,10 @@ def list_model_elements():
|
|||||||
|
|
||||||
|
|
||||||
def list_interface_input_elements(chat=False):
|
def list_interface_input_elements(chat=False):
|
||||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
|
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
|
||||||
if chat:
|
if chat:
|
||||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
|
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
|
||||||
|
|
||||||
elements += list_model_elements()
|
elements += list_model_elements()
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -44,11 +45,26 @@ def gather_interface_values(*args):
|
|||||||
output = {}
|
output = {}
|
||||||
for i, element in enumerate(shared.input_elements):
|
for i, element in enumerate(shared.input_elements):
|
||||||
output[element] = args[i]
|
output[element] = args[i]
|
||||||
|
|
||||||
|
shared.persistent_interface_state = output
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def apply_interface_values(state):
|
def apply_interface_values(state, use_persistent=False):
|
||||||
return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())]
|
if use_persistent:
|
||||||
|
state = shared.persistent_interface_state
|
||||||
|
|
||||||
|
elements = list_interface_input_elements(chat=shared.is_chat())
|
||||||
|
if len(state) == 0:
|
||||||
|
return [gr.update() for k in elements] # Dummy, do nothing
|
||||||
|
else:
|
||||||
|
if use_persistent and 'mode' in state:
|
||||||
|
if state['mode'] == 'instruct':
|
||||||
|
return [state[k] if k not in ['character_menu'] else gr.update() for k in elements]
|
||||||
|
else:
|
||||||
|
return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements]
|
||||||
|
else:
|
||||||
|
return [state[k] for k in elements]
|
||||||
|
|
||||||
|
|
||||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||||
|
@ -32,6 +32,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
@ -846,6 +847,8 @@ def create_interface():
|
|||||||
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
|
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
|
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
|
||||||
|
|
||||||
# Launch the interface
|
# Launch the interface
|
||||||
shared.gradio['interface'].queue()
|
shared.gradio['interface'].queue()
|
||||||
if shared.args.listen:
|
if shared.args.listen:
|
||||||
@ -855,7 +858,6 @@ def create_interface():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||||
@ -900,9 +902,11 @@ if __name__ == "__main__":
|
|||||||
print('The following models are available:\n')
|
print('The following models are available:\n')
|
||||||
for i, model in enumerate(available_models):
|
for i, model in enumerate(available_models):
|
||||||
print(f'{i+1}. {model}')
|
print(f'{i+1}. {model}')
|
||||||
|
|
||||||
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
|
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
|
||||||
i = int(input()) - 1
|
i = int(input()) - 1
|
||||||
print()
|
print()
|
||||||
|
|
||||||
shared.model_name = available_models[i]
|
shared.model_name = available_models[i]
|
||||||
|
|
||||||
# If any model has been selected, load it
|
# If any model has been selected, load it
|
||||||
|
Loading…
Reference in New Issue
Block a user