Add "Save current settings for this model" button

This commit is contained in:
oobabooga 2023-04-15 12:54:02 -03:00
parent b9dcba7762
commit ac189011cb
3 changed files with 40 additions and 12 deletions

1
.gitignore vendored
View File

@ -24,3 +24,4 @@ settings.json
img_bot*
img_me*
prompts/[0-9]*
models/config-user.yaml

View File

@ -45,17 +45,14 @@ def load_model(model_name):
shared.is_RWKV = 'rwkv-' in model_name.lower()
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
# Default settings
# Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps:
device = torch.device('mps')
model = model.to(device)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps:
device = torch.device('mps')
model = model.to(device)
else:
model = model.cuda()
model = model.cuda()
# FlexGen
elif shared.args.flexgen:

View File

@ -21,6 +21,7 @@ from pathlib import Path
import gradio as gr
import psutil
import torch
import yaml
from PIL import Image
import modules.extensions as extensions_module
@ -233,7 +234,7 @@ def get_model_specific_settings(model):
model_settings = {}
for pat in settings:
if re.match(pat, model.lower()):
if re.match(pat.lower(), model.lower()):
for k in settings[pat]:
model_settings[k] = settings[pat][k]
@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
return state
def save_model_settings(model, state):
if model == 'None':
yield ("Not saving the settings because no model is loaded.")
return
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
if model not in user_config:
user_config[model] = {}
for k in ui.list_model_elements():
user_config[model][k] = state[k]
with open(p, 'w') as f:
f.write(yaml.dump(user_config))
yield (f"Settings for {model} saved to {p}")
def create_model_menus():
# Finding the default values for the GPU and CPU memories
total_mem = []
@ -285,10 +309,12 @@ def create_model_menus():
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
with gr.Column():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row():
unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model")
save_settings = gr.Button("Save current settings for this model")
with gr.Row():
with gr.Column():
@ -344,7 +370,11 @@ def create_model_menus():
unload_model, None, None).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
update_model_parameters, shared.gradio['interface_state'], None).then(
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
save_settings.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)