From 7f664213692c31166869f67614453189dc4ab8ff Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:22:32 -0300 Subject: [PATCH] Fix loading characters --- modules/chat.py | 8 ++++---- modules/html_generator.py | 3 +-- server.py | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 978a08f2..1140b5fa 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -332,7 +332,7 @@ def generate_pfp_cache(character): return img return None -def load_character(character, name1, name2, instruct=False): +def load_character(character, name1, name2, mode): shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] @@ -345,7 +345,7 @@ def load_character(character, name1, name2, instruct=False): Path("cache/pfp_character.png").unlink() if character != 'None': - folder = "characters" if not instruct else "characters/instruction-following" + folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following' picture = generate_pfp_cache(character) for extension in ["yml", "yaml", "json"]: filepath = Path(f'{folder}/{character}.{extension}') @@ -386,10 +386,10 @@ def load_character(character, name1, name2, instruct=False): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, reset_cache=True) + return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) def load_default_history(name1, name2): - load_character("None", name1, name2) + load_character("None", name1, name2, "chat") def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') diff --git a/modules/html_generator.py b/modules/html_generator.py index 6fb8457f..e5c0bb56 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -203,8 +203,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_chat_html(history, name1, name2): return generate_cai_chat_html(history, name1, name2) -def chat_html_wrapper(history, name1, name2, mode="cai-chat", reset_cache=False): - +def chat_html_wrapper(history, name1, name2, mode, reset_cache=False): if mode == "cai-chat": return generate_cai_chat_html(history, name1, name2, reset_cache) elif mode == "chat": diff --git a/server.py b/server.py index 5bcb2cb7..8bcb6502 100644 --- a/server.py +++ b/server.py @@ -304,7 +304,7 @@ def create_interface(): if shared.is_chat(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): - shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): shared.gradio['Generate'] = gr.Button('Generate') @@ -412,8 +412,8 @@ def create_interface(): shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) - shared.gradio['Instruction templates'].change(lambda character, name1, name2: chat.load_character(character, name1, name2, instruct=True), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])