From 76d28eaa9e26ac3c8e6f9b06a1f7d25e75894f56 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 16 Feb 2024 14:21:17 -0300 Subject: [PATCH] Add a menu for customizing the instruction template for the model (#5521) --- modules/chat.py | 3 ++ modules/models_settings.py | 61 +++++++++++++++++++++++++++----------- modules/shared.py | 23 ++++++++++---- modules/ui_chat.py | 2 +- modules/ui_model_menu.py | 14 +++++++++ modules/utils.py | 2 +- 6 files changed, 81 insertions(+), 24 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index c431d2d0..de7f19de 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -691,6 +691,9 @@ def load_character(character, name1, name2): def load_instruction_template(template): + if template == 'None': + return '' + for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]: if filepath.exists(): break diff --git a/modules/models_settings.py b/modules/models_settings.py index 9acc7efa..b4473275 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -243,27 +243,54 @@ def save_model_settings(model, state): Save the settings for this model to models/config-user.yaml ''' if model == 'None': - yield ("Not saving the settings because no model is loaded.") + yield ("Not saving the settings because no model is selected in the menu.") return - with Path(f'{shared.args.model_dir}/config-user.yaml') as p: - if p.exists(): - user_config = yaml.safe_load(open(p, 'r').read()) - else: - user_config = {} + user_config = shared.load_user_config() + model_regex = model + '$' # For exact matches + if model_regex not in user_config: + user_config[model_regex] = {} - model_regex = model + '$' # For exact matches - if model_regex not in user_config: - user_config[model_regex] = {} + for k in ui.list_model_elements(): + if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: + user_config[model_regex][k] = state[k] - for k 
in ui.list_model_elements(): - if k == 'loader' or k in loaders.loaders_and_params[state['loader']]: - user_config[model_regex][k] = state[k] + shared.user_config = user_config - shared.user_config = user_config + output = yaml.dump(user_config, sort_keys=False) + p = Path(f'{shared.args.model_dir}/config-user.yaml') + with open(p, 'w') as f: + f.write(output) - output = yaml.dump(user_config, sort_keys=False) - with open(p, 'w') as f: - f.write(output) + yield (f"Settings for `{model}` saved to `{p}`.") - yield (f"Settings for `{model}` saved to `{p}`.") + +def save_instruction_template(model, template): + ''' + Similar to the function above, but it saves only the instruction template. + ''' + if model == 'None': + yield ("Not saving the template because no model is selected in the menu.") + return + + user_config = shared.load_user_config() + model_regex = model + '$' # For exact matches + if model_regex not in user_config: + user_config[model_regex] = {} + + if template == 'None': + user_config[model_regex].pop('instruction_template', None) + else: + user_config[model_regex]['instruction_template'] = template + + shared.user_config = user_config + + output = yaml.dump(user_config, sort_keys=False) + p = Path(f'{shared.args.model_dir}/config-user.yaml') + with open(p, 'w') as f: + f.write(output) + + if template == 'None': + yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.") + else: + yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.") diff --git a/modules/shared.py b/modules/shared.py index 2861d690..d8aef367 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -279,6 +279,23 @@ def is_chat(): return True +def load_user_config(): + ''' + Loads custom model-specific settings + ''' + if Path(f'{args.model_dir}/config-user.yaml').exists(): + file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip() + + if file_content: + user_config = 
yaml.safe_load(file_content) + else: + user_config = {} + else: + user_config = {} + + return user_config + + args.loader = fix_loader_name(args.loader) # Activate the multimodal extension @@ -297,11 +314,7 @@ with Path(f'{args.model_dir}/config.yaml') as p: model_config = {} # Load custom model-specific settings -with Path(f'{args.model_dir}/config-user.yaml') as p: - if p.exists(): - user_config = yaml.safe_load(open(p, 'r').read()) - else: - user_config = {} +user_config = load_user_config() model_config = OrderedDict(model_config) user_config = OrderedDict(user_config) diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 42e5cae2..7576628d 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -109,7 +109,7 @@ def create_chat_settings_ui(): with gr.Row(): with gr.Column(): with gr.Row(): - shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='Select template to load...', elem_classes='slim-dropdown') + shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='None', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button') shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index ca0de873..94b01937 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -17,6 +17,7 @@ from modules.models import load_model, unload_model from modules.models_settings import ( 
apply_model_settings_to_state, get_model_metadata, + save_instruction_template, save_model_settings, update_model_parameters ) @@ -165,6 +166,14 @@ def create_ui(): shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu) gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.") + with gr.Tab("Customize instruction template"): + with gr.Row(): + shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) + + shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu) + gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. 
Whenever the model gets loaded, this template will be used in place of the template specified in the model's metadata, which sometimes is wrong.") + with gr.Row(): shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') @@ -214,6 +223,7 @@ def create_event_handlers(): shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model')) shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True) + shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True) def load_model_wrapper(selected_model, loader, autoload=False): @@ -320,3 +330,7 @@ def update_truncation_length(current_length, state): return state['n_ctx'] return current_length + + +def save_model_template(model, template): + pass diff --git a/modules/utils.py b/modules/utils.py index be06ec34..4b65736b 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -114,7 +114,7 @@ def get_available_instruction_templates(): if os.path.exists(path): paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) - return ['Select template to load...'] + sorted(set((k.stem for k in paths)), key=natural_keys) + return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_extensions():