Reorganize some chat functions

This commit is contained in:
oobabooga 2023-04-07 11:07:03 -03:00
parent ec979cd9c4
commit a453d4e9c4

View File

@ -105,14 +105,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
eos_token = '\n' if generate_state['stop_at_newline'] else None # Defining some variables
cumulative_reply = ''
just_started = True
name1_original = name1 name1_original = name1
visible_text = custom_generate_chat_prompt = None
eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
# Check if any extension wants to hijack this function call # Check if any extension wants to hijack this function call
visible_text = None
custom_generate_chat_prompt = None
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False extension.input_hijack['state'] = False
@ -124,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
visible_text = text visible_text = text
text = apply_extensions(text, "input") text = apply_extensions(text, "input")
# Generating the prompt
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
@ -135,8 +138,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] + [[visible_text, shared.processing_message]] yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate # Generate
cumulative_reply = ''
just_started = True
for i in range(generate_state['chat_generation_attempts']): for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
@ -175,6 +176,8 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
# Defining some variables
cumulative_reply = ''
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
@ -184,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']): for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
@ -264,7 +266,7 @@ def redraw_html(name1, name2, mode):
def tokenize_dialogue(dialogue, name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2, mode):
history = [] history = []
messages = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue) dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
@ -273,7 +275,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
if len(idx) == 0: if len(idx) == 0:
return history return history
messages = []
for i in range(len(idx) - 1): for i in range(len(idx) - 1):
messages.append(dialogue[idx[i]:idx[i + 1]].strip()) messages.append(dialogue[idx[i]:idx[i + 1]].strip())
messages.append(dialogue[idx[-1]:].strip()) messages.append(dialogue[idx[-1]:].strip())