From e18534fe122c41211c486f0275dd9afc0a178d12 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 21 May 2023 22:05:59 -0300 Subject: [PATCH] Fix "continue" in chat-instruct mode --- modules/chat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/chat.py b/modules/chat.py index 5565b6c8..6e1930d2 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -75,6 +75,9 @@ def generate_chat_prompt(user_input, state, **kwargs): wrapper += all_substrings['instruct']['bot_turn_stripped'] if impersonate: wrapper += substrings['user_turn_stripped'].rstrip(' ') + elif _continue: + wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped']) + wrapper += history[-1][1] else: wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')) else: @@ -86,7 +89,8 @@ def generate_chat_prompt(user_input, state, **kwargs): rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"] while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length: if _continue and i == len(history) - 1: - rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip()) + if state['mode'] != 'chat-instruct': + rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip()) else: rows.insert(1, substrings['bot_turn'].replace('<|bot-message|>', history[i][1].strip()))