diff --git a/js/switch_tabs.js b/js/switch_tabs.js new file mode 100644 index 00000000..ed6c653d --- /dev/null +++ b/js/switch_tabs.js @@ -0,0 +1,31 @@ +let chat_tab = document.getElementById('chat-tab'); +let main_parent = chat_tab.parentNode; + +function switch_to_chat() { + let chat_tab_button = main_parent.childNodes[0].childNodes[1]; + chat_tab_button.click(); +} + +function switch_to_default() { + let default_tab_button = main_parent.childNodes[0].childNodes[4]; + default_tab_button.click(); +} + +function switch_to_notebook() { + let notebook_tab_button = main_parent.childNodes[0].childNodes[7]; + notebook_tab_button.click(); +} + +function switch_to_generation_parameters() { + let parameters_tab_button = main_parent.childNodes[0].childNodes[10]; + let generation_tab_button = document.getElementById('character-menu').parentNode.parentNode.parentNode.parentNode.parentNode.parentNode.childNodes[0].childNodes[1]; + parameters_tab_button.click(); + generation_tab_button.click(); +} + +function switch_to_character() { + let parameters_tab_button = main_parent.childNodes[0].childNodes[10]; + let character_tab_button = document.getElementById('character-menu').parentNode.parentNode.parentNode.parentNode.parentNode.parentNode.childNodes[0].childNodes[4]; + parameters_tab_button.click(); + character_tab_button.click(); +} diff --git a/modules/prompts.py b/modules/prompts.py index 8a3cf3e3..e7654fbf 100644 --- a/modules/prompts.py +++ b/modules/prompts.py @@ -1,4 +1,3 @@ -import re from pathlib import Path import yaml @@ -10,26 +9,6 @@ from modules.text_generation import get_encoded_length def load_prompt(fname): if fname in ['None', '']: return '' - elif fname.startswith('Instruct-'): - fname = re.sub('^Instruct-', '', fname) - file_path = Path(f'instruction-templates/{fname}.yaml') - if not file_path.exists(): - return '' - - with open(file_path, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - output = '' - if 'context' in data: - output += data['context'] - - replacements = { - '<|user|>': data['user'], - '<|bot|>': data['bot'], - '<|user-message|>': 'Input', - } - - output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements) - return output.rstrip(' ') else: file_path = Path(f'prompts/{fname}.txt') if not file_path.exists(): @@ -43,6 +22,27 @@ def load_prompt(fname): return text +def load_instruction_prompt_simple(fname): + file_path = Path(f'instruction-templates/{fname}.yaml') + if not file_path.exists(): + return '' + + with open(file_path, 'r', encoding='utf-8') as f: + data = yaml.safe_load(f) + output = '' + if 'context' in data: + output += data['context'] + + replacements = { + '<|user|>': data['user'], + '<|bot|>': data['bot'], + '<|user-message|>': 'Input', + } + + output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements) + return output.rstrip(' ') + + def count_tokens(text): try: tokens = get_encoded_length(text) diff --git a/modules/ui.py b/modules/ui.py index e7817f73..a7d7811e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,6 +12,8 @@ with open(Path(__file__).resolve().parent / '../js/main.js', 'r') as f: js = f.read() with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r') as f: save_files_js = f.read() +with open(Path(__file__).resolve().parent / '../js/switch_tabs.js', 'r') as f: + switch_tabs_js = f.read() refresh_symbol = '🔄' delete_symbol = '🗑️' diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 76e70ed0..461cf811 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -5,7 +5,7 @@ from pathlib import Path import gradio as gr from PIL import Image -from modules import chat, shared, ui, utils +from modules import chat, prompts, shared, ui, utils from modules.html_generator import chat_html_wrapper from modules.text_generation import stop_everything_event from modules.utils import gradio @@ -83,6 +83,11 @@ def create_chat_settings_ui(): shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string') shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context') shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.') + with gr.Row(): + shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button']) + shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button']) + shared.gradio['send_instruction_to_negative_prompt'] = gr.Button('Send to negative prompt', elem_classes=['small-button']) + with gr.Row(): shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar']) @@ -217,7 +222,7 @@ def create_event_handlers(): shared.gradio['load_chat_history'].upload( chat.load_history, gradio('load_chat_history', 'history'), gradio('history')).then( chat.redraw_html, gradio(reload_arr), gradio('display')).then( - None, None, None, _js='() => {alert("The history has been loaded.")}') + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_chat()}}') shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False) @@ -245,11 +250,11 @@ def create_event_handlers(): shared.gradio['Submit character'].click( chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu')).then( - None, None, None, _js='() => {alert("The character has been loaded.")}') + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') shared.gradio['Submit tavern character'].click( chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu')).then( - None, None, None, _js='() => {alert("The character has been loaded.")}') + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character')) shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character')) @@ -258,3 +263,15 @@ def create_event_handlers(): shared.gradio['your_picture'].change( chat.upload_your_profile_picture, gradio('your_picture'), None).then( partial(chat.redraw_html, reset_cache=True), gradio(reload_arr), gradio('display')) + + shared.gradio['send_instruction_to_default'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('textbox-default')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_default()}}') + + shared.gradio['send_instruction_to_notebook'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('textbox-notebook')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_notebook()}}') + + shared.gradio['send_instruction_to_negative_prompt'].click( + prompts.load_instruction_prompt_simple, gradio('instruction_template'), gradio('negative_prompt')).then( + lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_generation_parameters()}}') diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 2f0c2efd..c6d38804 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -98,7 +98,7 @@ def create_ui(default_preset): with gr.Row(): with gr.Column(): shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=generate_params['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.') - shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt') + shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', lines=3, elem_classes=['add_scrollbar']) shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') diff --git a/modules/utils.py b/modules/utils.py index 6fa94730..0a7edffa 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -88,7 +88,6 @@ def get_available_prompts(): files = set((k.stem for k in Path('prompts').glob('*.txt'))) prompts += sorted([k for k in files if re.match('^[0-9]', k)], key=natural_keys, reverse=True) prompts += sorted([k for k in files if re.match('^[^0-9]', k)], key=natural_keys) - prompts += ['Instruct-' + k for k in get_available_instruction_templates() if k != 'None'] prompts += ['None'] return prompts