mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add "Save current settings for this model" button
This commit is contained in:
parent
b9dcba7762
commit
ac189011cb
1
.gitignore
vendored
1
.gitignore
vendored
@ -24,3 +24,4 @@ settings.json
|
||||
img_bot*
|
||||
img_me*
|
||||
prompts/[0-9]*
|
||||
models/config-user.yaml
|
||||
|
@ -45,17 +45,14 @@ def load_model(model_name):
|
||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
||||
|
||||
# Default settings
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = model.cuda()
|
||||
model = model.cuda()
|
||||
|
||||
# FlexGen
|
||||
elif shared.args.flexgen:
|
||||
|
36
server.py
36
server.py
@ -21,6 +21,7 @@ from pathlib import Path
|
||||
import gradio as gr
|
||||
import psutil
|
||||
import torch
|
||||
import yaml
|
||||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
@ -233,7 +234,7 @@ def get_model_specific_settings(model):
|
||||
model_settings = {}
|
||||
|
||||
for pat in settings:
|
||||
if re.match(pat, model.lower()):
|
||||
if re.match(pat.lower(), model.lower()):
|
||||
for k in settings[pat]:
|
||||
model_settings[k] = settings[pat][k]
|
||||
|
||||
@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
|
||||
return state
|
||||
|
||||
|
||||
def save_model_settings(model, state):
|
||||
if model == 'None':
|
||||
yield ("Not saving the settings because no model is loaded.")
|
||||
return
|
||||
|
||||
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
|
||||
if p.exists():
|
||||
user_config = yaml.safe_load(open(p, 'r').read())
|
||||
else:
|
||||
user_config = {}
|
||||
|
||||
if model not in user_config:
|
||||
user_config[model] = {}
|
||||
|
||||
for k in ui.list_model_elements():
|
||||
user_config[model][k] = state[k]
|
||||
|
||||
with open(p, 'w') as f:
|
||||
f.write(yaml.dump(user_config))
|
||||
|
||||
yield (f"Settings for {model} saved to {p}")
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
# Finding the default values for the GPU and CPU memories
|
||||
total_mem = []
|
||||
@ -285,10 +309,12 @@ def create_model_menus():
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
|
||||
with gr.Row():
|
||||
unload = gr.Button("Unload the model")
|
||||
reload = gr.Button("Reload the model")
|
||||
save_settings = gr.Button("Save current settings for this model")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -344,7 +370,11 @@ def create_model_menus():
|
||||
unload_model, None, None).then(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
update_model_parameters, shared.gradio['interface_state'], None).then(
|
||||
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
|
||||
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
|
||||
save_settings.click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
|
||||
|
||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
|
Loading…
Reference in New Issue
Block a user