mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Add a menu for customizing the instruction template for the model (#5521)
This commit is contained in:
parent
0e1d8d5601
commit
76d28eaa9e
@ -691,6 +691,9 @@ def load_character(character, name1, name2):
|
||||
|
||||
|
||||
def load_instruction_template(template):
|
||||
if template == 'None':
|
||||
return ''
|
||||
|
||||
for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
|
||||
if filepath.exists():
|
||||
break
|
||||
|
@ -243,27 +243,54 @@ def save_model_settings(model, state):
|
||||
Save the settings for this model to models/config-user.yaml
|
||||
'''
|
||||
if model == 'None':
|
||||
yield ("Not saving the settings because no model is loaded.")
|
||||
yield ("Not saving the settings because no model is selected in the menu.")
|
||||
return
|
||||
|
||||
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
|
||||
if p.exists():
|
||||
user_config = yaml.safe_load(open(p, 'r').read())
|
||||
else:
|
||||
user_config = {}
|
||||
user_config = shared.load_user_config()
|
||||
model_regex = model + '$' # For exact matches
|
||||
if model_regex not in user_config:
|
||||
user_config[model_regex] = {}
|
||||
|
||||
model_regex = model + '$' # For exact matches
|
||||
if model_regex not in user_config:
|
||||
user_config[model_regex] = {}
|
||||
for k in ui.list_model_elements():
|
||||
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
||||
user_config[model_regex][k] = state[k]
|
||||
|
||||
for k in ui.list_model_elements():
|
||||
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
|
||||
user_config[model_regex][k] = state[k]
|
||||
shared.user_config = user_config
|
||||
|
||||
shared.user_config = user_config
|
||||
output = yaml.dump(user_config, sort_keys=False)
|
||||
p = Path(f'{shared.args.model_dir}/config-user.yaml')
|
||||
with open(p, 'w') as f:
|
||||
f.write(output)
|
||||
|
||||
output = yaml.dump(user_config, sort_keys=False)
|
||||
with open(p, 'w') as f:
|
||||
f.write(output)
|
||||
yield (f"Settings for `{model}` saved to `{p}`.")
|
||||
|
||||
yield (f"Settings for `{model}` saved to `{p}`.")
|
||||
|
||||
def save_instruction_template(model, template):
|
||||
'''
|
||||
Similar to the function above, but it saves only the instruction template.
|
||||
'''
|
||||
if model == 'None':
|
||||
yield ("Not saving the template because no model is selected in the menu.")
|
||||
return
|
||||
|
||||
user_config = shared.load_user_config()
|
||||
model_regex = model + '$' # For exact matches
|
||||
if model_regex not in user_config:
|
||||
user_config[model_regex] = {}
|
||||
|
||||
if template == 'None':
|
||||
user_config[model_regex].pop('instruction_template', None)
|
||||
else:
|
||||
user_config[model_regex]['instruction_template'] = template
|
||||
|
||||
shared.user_config = user_config
|
||||
|
||||
output = yaml.dump(user_config, sort_keys=False)
|
||||
p = Path(f'{shared.args.model_dir}/config-user.yaml')
|
||||
with open(p, 'w') as f:
|
||||
f.write(output)
|
||||
|
||||
if template == 'None':
|
||||
yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.")
|
||||
else:
|
||||
yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.")
|
||||
|
@ -279,6 +279,23 @@ def is_chat():
|
||||
return True
|
||||
|
||||
|
||||
def load_user_config():
|
||||
'''
|
||||
Loads custom model-specific settings
|
||||
'''
|
||||
if Path(f'{args.model_dir}/config-user.yaml').exists():
|
||||
file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip()
|
||||
|
||||
if file_content:
|
||||
user_config = yaml.safe_load(file_content)
|
||||
else:
|
||||
user_config = {}
|
||||
else:
|
||||
user_config = {}
|
||||
|
||||
return user_config
|
||||
|
||||
|
||||
args.loader = fix_loader_name(args.loader)
|
||||
|
||||
# Activate the multimodal extension
|
||||
@ -297,11 +314,7 @@ with Path(f'{args.model_dir}/config.yaml') as p:
|
||||
model_config = {}
|
||||
|
||||
# Load custom model-specific settings
|
||||
with Path(f'{args.model_dir}/config-user.yaml') as p:
|
||||
if p.exists():
|
||||
user_config = yaml.safe_load(open(p, 'r').read())
|
||||
else:
|
||||
user_config = {}
|
||||
user_config = load_user_config()
|
||||
|
||||
model_config = OrderedDict(model_config)
|
||||
user_config = OrderedDict(user_config)
|
||||
|
@ -109,7 +109,7 @@ def create_chat_settings_ui():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='Select template to load...', elem_classes='slim-dropdown')
|
||||
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='None', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
|
||||
shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button')
|
||||
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
|
||||
|
@ -17,6 +17,7 @@ from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (
|
||||
apply_model_settings_to_state,
|
||||
get_model_metadata,
|
||||
save_instruction_template,
|
||||
save_model_settings,
|
||||
update_model_parameters
|
||||
)
|
||||
@ -165,6 +166,14 @@ def create_ui():
|
||||
shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu)
|
||||
gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.")
|
||||
|
||||
with gr.Tab("Customize instruction template"):
|
||||
with gr.Row():
|
||||
shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
|
||||
ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
|
||||
|
||||
shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
|
||||
gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenver the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
|
||||
|
||||
@ -214,6 +223,7 @@ def create_event_handlers():
|
||||
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
|
||||
shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True)
|
||||
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
@ -320,3 +330,7 @@ def update_truncation_length(current_length, state):
|
||||
return state['n_ctx']
|
||||
|
||||
return current_length
|
||||
|
||||
|
||||
def save_model_template(model, template):
|
||||
pass
|
||||
|
@ -114,7 +114,7 @@ def get_available_instruction_templates():
|
||||
if os.path.exists(path):
|
||||
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||
|
||||
return ['Select template to load...'] + sorted(set((k.stem for k in paths)), key=natural_keys)
|
||||
return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)
|
||||
|
||||
|
||||
def get_available_extensions():
|
||||
|
Loading…
Reference in New Issue
Block a user