Add the option to not automatically load the selected model (#1762)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
LaaZa 2023-05-09 18:52:35 +00:00 committed by GitHub
parent cf6caf1830
commit 218bd64bd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 11 deletions

View File

@ -31,6 +31,7 @@ input_params = []
need_restart = False need_restart = False
settings = { settings = {
'autoload_model': True,
'max_new_tokens': 200, 'max_new_tokens': 200,
'max_new_tokens_min': 1, 'max_new_tokens_min': 1,
'max_new_tokens_max': 2000, 'max_new_tokens_max': 2000,

View File

@ -51,7 +51,14 @@ from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import encode, generate_reply, stop_everything_event from modules.text_generation import encode, generate_reply, stop_everything_event
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model, autoload=False):
if not autoload:
yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
return
if selected_model == 'None':
yield "No model selected"
else:
try: try:
yield f"Loading {selected_model}..." yield f"Loading {selected_model}..."
shared.model_name = selected_model shared.model_name = selected_model
@ -292,6 +299,7 @@ def create_model_menus():
with gr.Row(): with gr.Row():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs') shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row(): with gr.Row():
load = gr.Button("Load the model", visible=not shared.settings['autoload_model'])
unload = gr.Button("Unload the model") unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model") reload = gr.Button("Reload the model")
save_settings = gr.Button("Save settings for this model") save_settings = gr.Button("Save settings for this model")
@ -327,6 +335,9 @@ def create_model_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row():
shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown. You can change the default with a settings.json file.')
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main") shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
shared.gradio['download_model_button'] = gr.Button("Download") shared.gradio['download_model_button'] = gr.Button("Download")
@ -335,12 +346,20 @@ def create_model_menus():
# In this event handler, the interface state is read and updated # In this event handler, the interface state is read and updated
# with the model defaults (if any), and then the model is loaded # with the model defaults (if any), and then the model is loaded
# unless "autoload_model" is unchecked
shared.gradio['model_menu'].change( shared.gradio['model_menu'].change(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
load_model_specific_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then( load_model_specific_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then(
ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then( ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, shared.gradio['interface_state'], None).then(
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True) load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'autoload_model']], shared.gradio['model_status'], show_progress=False)
load.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
ui.apply_interface_values, shared.gradio['interface_state'],
[shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
update_model_parameters, shared.gradio['interface_state'], None).then(
partial(load_model_wrapper, autoload=True), shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
unload.click( unload.click(
unload_model, None, None).then( unload_model, None, None).then(
@ -358,6 +377,7 @@ def create_model_menus():
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):

View File

@ -1,4 +1,5 @@
{ {
"autoload_model": true,
"max_new_tokens": 200, "max_new_tokens": 200,
"max_new_tokens_min": 1, "max_new_tokens_min": 1,
"max_new_tokens_max": 2000, "max_new_tokens_max": 2000,