mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Add "Save current settings for this model" button
This commit is contained in:
parent
b9dcba7762
commit
ac189011cb
1
.gitignore
vendored
1
.gitignore
vendored
@ -24,3 +24,4 @@ settings.json
|
|||||||
img_bot*
|
img_bot*
|
||||||
img_me*
|
img_me*
|
||||||
prompts/[0-9]*
|
prompts/[0-9]*
|
||||||
|
models/config-user.yaml
|
||||||
|
@ -45,11 +45,8 @@ def load_model(model_name):
|
|||||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||||
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
||||||
|
|
||||||
# Default settings
|
# Load the model in simple 16-bit mode by default
|
||||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
||||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
|
||||||
else:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||||
if torch.has_mps:
|
if torch.has_mps:
|
||||||
device = torch.device('mps')
|
device = torch.device('mps')
|
||||||
|
34
server.py
34
server.py
@ -21,6 +21,7 @@ from pathlib import Path
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
@ -233,7 +234,7 @@ def get_model_specific_settings(model):
|
|||||||
model_settings = {}
|
model_settings = {}
|
||||||
|
|
||||||
for pat in settings:
|
for pat in settings:
|
||||||
if re.match(pat, model.lower()):
|
if re.match(pat.lower(), model.lower()):
|
||||||
for k in settings[pat]:
|
for k in settings[pat]:
|
||||||
model_settings[k] = settings[pat][k]
|
model_settings[k] = settings[pat][k]
|
||||||
|
|
||||||
@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
|
|||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def save_model_settings(model, state):
|
||||||
|
if model == 'None':
|
||||||
|
yield ("Not saving the settings because no model is loaded.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
|
||||||
|
if p.exists():
|
||||||
|
user_config = yaml.safe_load(open(p, 'r').read())
|
||||||
|
else:
|
||||||
|
user_config = {}
|
||||||
|
|
||||||
|
if model not in user_config:
|
||||||
|
user_config[model] = {}
|
||||||
|
|
||||||
|
for k in ui.list_model_elements():
|
||||||
|
user_config[model][k] = state[k]
|
||||||
|
|
||||||
|
with open(p, 'w') as f:
|
||||||
|
f.write(yaml.dump(user_config))
|
||||||
|
|
||||||
|
yield (f"Settings for {model} saved to {p}")
|
||||||
|
|
||||||
|
|
||||||
def create_model_menus():
|
def create_model_menus():
|
||||||
# Finding the default values for the GPU and CPU memories
|
# Finding the default values for the GPU and CPU memories
|
||||||
total_mem = []
|
total_mem = []
|
||||||
@ -285,10 +309,12 @@ def create_model_menus():
|
|||||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
|
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
unload = gr.Button("Unload the model")
|
unload = gr.Button("Unload the model")
|
||||||
reload = gr.Button("Reload the model")
|
reload = gr.Button("Reload the model")
|
||||||
|
save_settings = gr.Button("Save current settings for this model")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -344,7 +370,11 @@ def create_model_menus():
|
|||||||
unload_model, None, None).then(
|
unload_model, None, None).then(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
update_model_parameters, shared.gradio['interface_state'], None).then(
|
update_model_parameters, shared.gradio['interface_state'], None).then(
|
||||||
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
|
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||||
|
|
||||||
|
save_settings.click(
|
||||||
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
|
||||||
|
|
||||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
||||||
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user