From dcc3e540058752bb3fcf223cf85e477c73efdd71 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 21 May 2023 22:54:28 -0300
Subject: [PATCH] Various "impersonate" fixes

---
 modules/chat.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index 7e980f32..a4d10f2d 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     history = kwargs.get('history', shared.history)['internal']
     is_instruct = state['mode'] == 'instruct'
 
-    # FInd the maximum prompt size
+    # Find the maximum prompt size
     chat_prompt_size = state['chat_prompt_size']
     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
@@ -295,16 +295,19 @@ def impersonate_wrapper(text, state):
         for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
             reply = cumulative_reply + reply
             reply, next_character_found = extract_message_from_reply(reply, state)
-            yield reply
+            yield reply.lstrip(' ')
+            if shared.stop_everything:
+                return
+
             if next_character_found:
                 break
 
-        if reply in [None, '']:
+        if reply in [None, cumulative_reply]:
             break
         else:
             cumulative_reply = reply
 
-    yield cumulative_reply
+    yield cumulative_reply.lstrip(' ')
 
 
 def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
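
Reviewer note: below is a minimal, self-contained sketch of the streaming
loop these fixes target, runnable on its own. It is illustrative only:
fake_generate_reply and the _Shared stop flag are stand-ins for the real
generate_reply() and modules.shared, and the real loop also runs
extract_message_from_reply() and checks next_character_found, which are
omitted here for brevity.

    class _Shared:
        stop_everything = False  # set True when the user presses "Stop"

    shared = _Shared()
    calls = {'n': 0}

    def fake_generate_reply(prompt):
        # Streams progressively longer strings, like the real backend does.
        calls['n'] += 1
        if calls['n'] == 1:
            for partial in [' Hel', ' Hello', ' Hello there']:
                yield partial
        else:
            yield ''  # a later attempt that generates nothing new

    def impersonate_sketch(text, attempts=2):
        cumulative_reply = text
        reply = None
        for _ in range(attempts):
            for reply in fake_generate_reply(cumulative_reply):
                reply = cumulative_reply + reply
                # Fix 1: strip the leading space that the chat prompt format
                # tends to inject before showing the partial text.
                yield reply.lstrip(' ')
                # Fix 2: honor the stop flag mid-stream instead of finishing
                # the whole generation first.
                if shared.stop_everything:
                    return

            # Fix 3: comparing against cumulative_reply (not '') ends the
            # retry loop as soon as an attempt adds no new text.
            if reply in [None, cumulative_reply]:
                break
            cumulative_reply = reply

        yield cumulative_reply.lstrip(' ')

    for partial in impersonate_sketch('Hi'):
        print(repr(partial))
    # Streams 'Hi Hel', 'Hi Hello', 'Hi Hello there'; the second attempt adds
    # nothing, so fix 3 stops the loop before the final yield repeats the text.

The cumulative_reply comparison matters because reply holds the accumulated
text and is never '' once an earlier attempt succeeded, so the old check
could burn through every remaining generation attempt without adding anything.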