diff --git a/modules/chat.py b/modules/chat.py index 6e1930d2..386fae0a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs): history = kwargs.get('history', shared.history)['internal'] is_instruct = state['mode'] == 'instruct' - # Finding the maximum prompt size + # FInd the maximum prompt size chat_prompt_size = state['chat_prompt_size'] if shared.soft_prompt: chat_prompt_size -= shared.soft_prompt_tensor.shape[1] @@ -66,7 +66,7 @@ def generate_chat_prompt(user_input, state, **kwargs): substrings = all_substrings['instruct' if is_instruct else 'chat'] - # Creating the template for "chat-instruct" mode + # Create the template for "chat-instruct" mode if state['mode'] == 'chat-instruct': wrapper = '' command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1']) @@ -83,7 +83,7 @@ def generate_chat_prompt(user_input, state, **kwargs): else: wrapper = '<|prompt|>' - # Building the prompt + # Build the prompt min_rows = 3 i = len(history) - 1 rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"] @@ -107,11 +107,11 @@ def generate_chat_prompt(user_input, state, **kwargs): min_rows = 2 rows.append(substrings['user_turn_stripped'].rstrip(' ')) elif not _continue: - # Adding the user message + # Add the user message if len(user_input) > 0: rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))})) - # Adding the Character prefix + # Add the character prefix if state['mode'] != 'chat-instruct': rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' '))) @@ -192,7 +192,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa return # Defining some variables - cumulative_reply = '' just_started = True visible_text = None eos_token = '\n' if state['stop_at_newline'] else None @@ -232,12 +231,13 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa prompt = generate_chat_prompt(text, state, **kwargs) # Generate + cumulative_reply = '' for i in range(state['chat_generation_attempts']): reply = None for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)): reply = cumulative_reply + reply - # Extracting the reply + # Extract the reply reply, next_character_found = extract_message_from_reply(reply, state) visible_reply = re.sub("(||{{user}})", state['name1'], reply) visible_reply = apply_extensions("output", visible_reply) @@ -268,7 +268,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa if next_character_found: break - if reply in [None, '']: + if reply in [None, cumulative_reply]: break else: cumulative_reply = reply