Fix generation_attempts continuing after an empty reply

This commit is contained in:
oobabooga 2023-05-21 22:14:50 -03:00
parent e18534fe12
commit fb91406e93

View File

@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
history = kwargs.get('history', shared.history)['internal'] history = kwargs.get('history', shared.history)['internal']
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# Finding the maximum prompt size # FInd the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt: if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
@ -66,7 +66,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
substrings = all_substrings['instruct' if is_instruct else 'chat'] substrings = all_substrings['instruct' if is_instruct else 'chat']
# Creating the template for "chat-instruct" mode # Create the template for "chat-instruct" mode
if state['mode'] == 'chat-instruct': if state['mode'] == 'chat-instruct':
wrapper = '' wrapper = ''
command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1']) command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
@ -83,7 +83,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
else: else:
wrapper = '<|prompt|>' wrapper = '<|prompt|>'
# Building the prompt # Build the prompt
min_rows = 3 min_rows = 3
i = len(history) - 1 i = len(history) - 1
rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"] rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"]
@ -107,11 +107,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
min_rows = 2 min_rows = 2
rows.append(substrings['user_turn_stripped'].rstrip(' ')) rows.append(substrings['user_turn_stripped'].rstrip(' '))
elif not _continue: elif not _continue:
# Adding the user message # Add the user message
if len(user_input) > 0: if len(user_input) > 0:
rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))})) rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
# Adding the Character prefix # Add the character prefix
if state['mode'] != 'chat-instruct': if state['mode'] != 'chat-instruct':
rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' '))) rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')))
@ -192,7 +192,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
return return
# Defining some variables # Defining some variables
cumulative_reply = ''
just_started = True just_started = True
visible_text = None visible_text = None
eos_token = '\n' if state['stop_at_newline'] else None eos_token = '\n' if state['stop_at_newline'] else None
@ -232,12 +231,13 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
prompt = generate_chat_prompt(text, state, **kwargs) prompt = generate_chat_prompt(text, state, **kwargs)
# Generate # Generate
cumulative_reply = ''
for i in range(state['chat_generation_attempts']): for i in range(state['chat_generation_attempts']):
reply = None reply = None
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)): for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extract the reply
reply, next_character_found = extract_message_from_reply(reply, state) reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions("output", visible_reply) visible_reply = apply_extensions("output", visible_reply)
@ -268,7 +268,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if next_character_found: if next_character_found:
break break
if reply in [None, '']: if reply in [None, cumulative_reply]:
break break
else: else:
cumulative_reply = reply cumulative_reply = reply