From 169209805d6703f11f9214bcc1eda555a2642b60 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 2 Mar 2023 11:25:04 -0300 Subject: [PATCH] Model-aware prompts and presets --- modules/shared.py | 15 +++++++++++---- server.py | 28 ++++++++++++---------------- settings-template.json | 21 ++++++++++++++++----- 3 files changed, 39 insertions(+), 25 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index d59c1344..90db11c4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -22,12 +22,9 @@ settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, 'max_new_tokens_max': 2000, - 'preset': 'NovelAI-Sphinx Moth', 'name1': 'Person 1', 'name2': 'Person 2', 'context': 'This is a conversation between two people.', - 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', - 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n', 'stop_at_newline': True, 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, @@ -35,13 +32,23 @@ settings = { 'chat_generation_attempts': 1, 'chat_generation_attempts_min': 1, 'chat_generation_attempts_max': 5, - 'preset_pygmalion': 'Pygmalion', 'name1_pygmalion': 'You', 'name2_pygmalion': 'Kawaii', 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", 'stop_at_newline_pygmalion': False, 'default_extensions': [], 'chat_default_extensions': ["gallery"], + 'presets': { + 'default': 'NovelAI-Sphinx Moth', + 'pygmalion-*': 'Pygmalion', + 'RWKV-*': 'Naive', + '(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)' + }, + 'prompts': { + 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', + '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', + '(rosey|chip|joi)_.*_instruct.*': 'User: \n' + } } parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) diff --git a/server.py b/server.py index 523fcff3..f8a7693c 100644 --- a/server.py +++ b/server.py @@ -94,8 +94,8 @@ def upload_soft_prompt(file): return name -def create_settings_menus(): - generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True) +def create_settings_menus(default_preset): + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) with gr.Row(): with gr.Column(): @@ -104,7 +104,7 @@ def create_settings_menus(): ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button') with gr.Column(): with gr.Row(): - shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset') + shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'): @@ -150,8 +150,8 @@ available_presets = get_available_presets() available_characters = get_available_characters() available_softprompts = get_available_softprompts() +# Default extensions extensions_module.available_extensions = get_available_extensions() -# Activate the default extensions if shared.args.chat or shared.args.cai_chat: for extension in shared.settings['chat_default_extensions']: shared.args.extensions = shared.args.extensions or [] @@ -165,7 +165,7 @@ else: if shared.args.extensions is not None and len(shared.args.extensions) > 0: extensions_module.load_extensions() -# Choosing the default model +# Default model if shared.args.model is not None: shared.model_name = shared.args.model else: @@ -184,16 +184,12 @@ else: shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) -# UI settings +# Default UI settings gen_events = [] -suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' +default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] +default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' -if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): - default_text = shared.settings['prompt_gpt4chan'] -elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None: - default_text = 'User: \n' -else: - default_text = shared.settings['prompt'] +suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' if shared.args.chat or shared.args.cai_chat: with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']: @@ -257,7 +253,7 @@ if shared.args.chat or shared.args.cai_chat: with gr.Column(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - create_settings_menus() + create_settings_menus(default_preset) shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] if shared.args.extensions is not None: @@ -321,7 +317,7 @@ elif shared.args.notebook: shared.gradio['Stop'] = gr.Button('Stop') shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - create_settings_menus() + create_settings_menus(default_preset) if shared.args.extensions is not None: extensions_module.create_extensions_block() @@ -345,7 +341,7 @@ else: with gr.Column(): shared.gradio['Stop'] = gr.Button('Stop') - create_settings_menus() + create_settings_menus(default_preset) if shared.args.extensions is not None: extensions_module.create_extensions_block() diff --git a/settings-template.json b/settings-template.json index 13165641..6585f313 100644 --- a/settings-template.json +++ b/settings-template.json @@ -2,12 +2,9 @@ "max_new_tokens": 200, "max_new_tokens_min": 1, "max_new_tokens_max": 2000, - "preset": "NovelAI-Sphinx Moth", "name1": "Person 1", "name2": "Person 2", "context": "This is a conversation between two people.", - "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:", - "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n", "stop_at_newline": true, "chat_prompt_size": 2048, "chat_prompt_size_min": 0, @@ -15,9 +12,23 @@ "chat_generation_attempts": 1, "chat_generation_attempts_min": 1, "chat_generation_attempts_max": 5, - "preset_pygmalion": "Pygmalion", "name1_pygmalion": "You", "name2_pygmalion": "Kawaii", "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", - "stop_at_newline_pygmalion": false + "stop_at_newline_pygmalion": false, + "default_extensions": [], + "chat_default_extensions": [ + "gallery" + ], + "presets": { + "default": "NovelAI-Sphinx Moth", + "pygmalion-*": "Pygmalion", + "RWKV-*": "Naive", + "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)" + }, + "prompts": { + "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", + "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", + "(rosey|chip|joi)_.*_instruct.*": "User: \n" + } }