Add a menu for customizing the instruction template for the model (#5521)

This commit is contained in:
oobabooga 2024-02-16 14:21:17 -03:00 committed by GitHub
parent 0e1d8d5601
commit 76d28eaa9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 81 additions and 24 deletions

View File

@ -691,6 +691,9 @@ def load_character(character, name1, name2):
def load_instruction_template(template): def load_instruction_template(template):
if template == 'None':
return ''
for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]: for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
if filepath.exists(): if filepath.exists():
break break

View File

@ -243,15 +243,10 @@ def save_model_settings(model, state):
Save the settings for this model to models/config-user.yaml Save the settings for this model to models/config-user.yaml
''' '''
if model == 'None': if model == 'None':
yield ("Not saving the settings because no model is loaded.") yield ("Not saving the settings because no model is selected in the menu.")
return return
with Path(f'{shared.args.model_dir}/config-user.yaml') as p: user_config = shared.load_user_config()
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
model_regex = model + '$' # For exact matches model_regex = model + '$' # For exact matches
if model_regex not in user_config: if model_regex not in user_config:
user_config[model_regex] = {} user_config[model_regex] = {}
@ -263,7 +258,39 @@ def save_model_settings(model, state):
shared.user_config = user_config shared.user_config = user_config
output = yaml.dump(user_config, sort_keys=False) output = yaml.dump(user_config, sort_keys=False)
p = Path(f'{shared.args.model_dir}/config-user.yaml')
with open(p, 'w') as f: with open(p, 'w') as f:
f.write(output) f.write(output)
yield (f"Settings for `{model}` saved to `{p}`.") yield (f"Settings for `{model}` saved to `{p}`.")
def save_instruction_template(model, template):
    '''
    Persist only the instruction template for `model` to models/config-user.yaml.

    Adds (or, when `template` is 'None', removes) the 'instruction_template'
    key under the model's exact-match regex entry, mirrors the result into
    shared.user_config, and yields a status message for the UI.
    '''
    if model == 'None':
        yield ("Not saving the template because no model is selected in the menu.")
        return

    config = shared.load_user_config()
    key = model + '$'  # For exact matches
    entry = config.setdefault(key, {})

    unset = (template == 'None')
    if unset:
        # Remove any previously saved template; missing key is fine.
        entry.pop('instruction_template', None)
    else:
        entry['instruction_template'] = template

    # Keep the in-memory copy in sync with what is written to disk.
    shared.user_config = config

    destination = Path(f'{shared.args.model_dir}/config-user.yaml')
    with open(destination, 'w') as f:
        f.write(yaml.dump(config, sort_keys=False))

    if unset:
        yield (f"Instruction template for `{model}` unset in `{destination}`, as the value for template was `{template}`.")
    else:
        yield (f"Instruction template for `{model}` saved to `{destination}` as `{template}`.")

View File

@ -279,6 +279,23 @@ def is_chat():
return True return True
def load_user_config():
    '''
    Load custom model-specific settings from models/config-user.yaml.

    Returns a dict mapping model-name regexes to settings dicts; returns an
    empty dict when the file is missing or effectively empty.
    '''
    path = Path(f'{args.model_dir}/config-user.yaml')
    if path.exists():
        # read_text() closes the file, unlike the bare open(...).read()
        # this replaces (which leaked the handle until GC).
        file_content = path.read_text(encoding='utf-8').strip()
    else:
        file_content = ''

    if file_content:
        # safe_load returns None for comment-only content; normalize to {}
        # so callers can index into the result safely.
        user_config = yaml.safe_load(file_content) or {}
    else:
        user_config = {}

    return user_config
args.loader = fix_loader_name(args.loader) args.loader = fix_loader_name(args.loader)
# Activate the multimodal extension # Activate the multimodal extension
@ -297,11 +314,7 @@ with Path(f'{args.model_dir}/config.yaml') as p:
model_config = {} model_config = {}
# Load custom model-specific settings # Load custom model-specific settings
with Path(f'{args.model_dir}/config-user.yaml') as p: user_config = load_user_config()
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
model_config = OrderedDict(model_config) model_config = OrderedDict(model_config)
user_config = OrderedDict(user_config) user_config = OrderedDict(user_config)

View File

@ -109,7 +109,7 @@ def create_chat_settings_ui():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='Select template to load...', elem_classes='slim-dropdown') shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='None', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button') shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button')
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)

View File

@ -17,6 +17,7 @@ from modules.models import load_model, unload_model
from modules.models_settings import ( from modules.models_settings import (
apply_model_settings_to_state, apply_model_settings_to_state,
get_model_metadata, get_model_metadata,
save_instruction_template,
save_model_settings, save_model_settings,
update_model_parameters update_model_parameters
) )
@ -165,6 +166,14 @@ def create_ui():
shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu) shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu)
gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.") gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.")
with gr.Tab("Customize instruction template"):
with gr.Row():
shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's metadata, which sometimes is wrong.") gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenever the model gets loaded, this template will be used in place of the template specified in the model's metadata, which sometimes is wrong.")
with gr.Row(): with gr.Row():
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
@ -214,6 +223,7 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model')) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True) shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)
def load_model_wrapper(selected_model, loader, autoload=False): def load_model_wrapper(selected_model, loader, autoload=False):
@ -320,3 +330,7 @@ def update_truncation_length(current_length, state):
return state['n_ctx'] return state['n_ctx']
return current_length return current_length
def save_model_template(model, template):
    # NOTE(review): dead stub — the actual handler for the "Customize
    # instruction template" submit button is save_instruction_template in
    # modules/models_settings.py. This appears to be leftover from an earlier
    # draft of the commit; TODO: remove it.
    pass

View File

@ -114,7 +114,7 @@ def get_available_instruction_templates():
if os.path.exists(path): if os.path.exists(path):
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['Select template to load...'] + sorted(set((k.stem for k in paths)), key=natural_keys) return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)
def get_available_extensions(): def get_available_extensions():