From a453d4e9c4c9ab76f3de138d3cc311c7b5ac261d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 7 Apr 2023 11:07:03 -0300 Subject: [PATCH] Reorganize some chat functions --- modules/chat.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 10146086..207009da 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -105,14 +105,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu else: stopping_strings = [f"\n{name1}:", f"\n{name2}:"] - eos_token = '\n' if generate_state['stop_at_newline'] else None + # Defining some variables + cumulative_reply = '' + just_started = True name1_original = name1 + visible_text = custom_generate_chat_prompt = None + eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" # Check if any extension wants to hijack this function call - visible_text = None - custom_generate_chat_prompt = None for extension, _ in extensions_module.iterator(): if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: extension.input_hijack['state'] = False @@ -124,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu visible_text = text text = apply_extensions(text, "input") + # Generating the prompt kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} if custom_generate_chat_prompt is None: prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) @@ -135,8 +138,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu yield shared.history['visible'] + [[visible_text, shared.processing_message]] # Generate - cumulative_reply = '' - just_started = True for i in range(generate_state['chat_generation_attempts']): reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): @@ -175,6 +176,8 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o else: stopping_strings = [f"\n{name1}:", f"\n{name2}:"] + # Defining some variables + cumulative_reply = '' eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -184,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o # Yield *Is typing...* yield shared.processing_message - cumulative_reply = '' for i in range(generate_state['chat_generation_attempts']): reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): @@ -264,7 +266,7 @@ def redraw_html(name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2, mode): history = [] - + messages = [] dialogue = re.sub('', '', dialogue) dialogue = re.sub('', '', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) @@ -273,7 +275,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode): if len(idx) == 0: return history - messages = [] for i in range(len(idx) - 1): messages.append(dialogue[idx[i]:idx[i + 1]].strip()) messages.append(dialogue[idx[-1]:].strip())