Add "Save current settings for this model" button

This commit is contained in:
oobabooga 2023-04-15 12:54:02 -03:00
parent b9dcba7762
commit ac189011cb
3 changed files with 40 additions and 12 deletions

1
.gitignore vendored
View File

@ -24,3 +24,4 @@ settings.json
img_bot* img_bot*
img_me* img_me*
prompts/[0-9]* prompts/[0-9]*
models/config-user.yaml

View File

@ -45,11 +45,8 @@ def load_model(model_name):
shared.is_RWKV = 'rwkv-' in model_name.lower() shared.is_RWKV = 'rwkv-' in model_name.lower()
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0 shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
# Default settings # Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps: if torch.has_mps:
device = torch.device('mps') device = torch.device('mps')

View File

@ -21,6 +21,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
import psutil import psutil
import torch import torch
import yaml
from PIL import Image from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
@ -233,7 +234,7 @@ def get_model_specific_settings(model):
model_settings = {} model_settings = {}
for pat in settings: for pat in settings:
if re.match(pat, model.lower()): if re.match(pat.lower(), model.lower()):
for k in settings[pat]: for k in settings[pat]:
model_settings[k] = settings[pat][k] model_settings[k] = settings[pat][k]
@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
return state return state
def save_model_settings(model, state):
if model == 'None':
yield ("Not saving the settings because no model is loaded.")
return
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
if model not in user_config:
user_config[model] = {}
for k in ui.list_model_elements():
user_config[model][k] = state[k]
with open(p, 'w') as f:
f.write(yaml.dump(user_config))
yield (f"Settings for {model} saved to {p}")
def create_model_menus(): def create_model_menus():
# Finding the default values for the GPU and CPU memories # Finding the default values for the GPU and CPU memories
total_mem = [] total_mem = []
@ -285,10 +309,12 @@ def create_model_menus():
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
with gr.Column(): with gr.Column():
with gr.Row():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs') shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row(): with gr.Row():
unload = gr.Button("Unload the model") unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model") reload = gr.Button("Reload the model")
save_settings = gr.Button("Save current settings for this model")
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -344,7 +370,11 @@ def create_model_menus():
unload_model, None, None).then( unload_model, None, None).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
update_model_parameters, shared.gradio['interface_state'], None).then( update_model_parameters, shared.gradio['interface_state'], None).then(
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True) load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
save_settings.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)