From 310bf46a945aacc507454509b82f6807c48cc093 Mon Sep 17 00:00:00 2001 From: OWKenobi Date: Thu, 6 Apr 2023 22:40:44 +0200 Subject: [PATCH] Instruction Character Vicuna, Instruction Mode Bugfix (#838) --- characters/instruction-following/Vicuna.yaml | 3 +++ modules/chat.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 characters/instruction-following/Vicuna.yaml diff --git a/characters/instruction-following/Vicuna.yaml b/characters/instruction-following/Vicuna.yaml new file mode 100644 index 00000000..026901d4 --- /dev/null +++ b/characters/instruction-following/Vicuna.yaml @@ -0,0 +1,3 @@ +name: "### Assistant:" +your_name: "### Human:" +context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." diff --git a/modules/chat.py b/modules/chat.py index 749ff8c2..36932641 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -99,6 +99,11 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): return reply, next_character_found def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): + if mode == 'instruct': + stopping_strings = [f"\n{name1}", f"\n{name2}"] + else: + stopping_strings = [f"\n{name1}:", f"\n{name2}:"] + eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): @@ -133,7 +138,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu just_started = True for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): reply = cumulative_reply + reply # Extracting the reply @@ -163,6 +168,11 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu yield shared.history['visible'] def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + if mode == 'instruct': + stopping_strings = [f"\n{name1}", f"\n{name2}"] + else: + stopping_strings = [f"\n{name1}:", f"\n{name2}:"] + eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -175,7 +185,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o cumulative_reply = '' for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): reply = cumulative_reply + reply reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) yield reply