Fix "continue" in chat-instruct mode

This commit is contained in:
oobabooga 2023-05-21 22:05:59 -03:00
parent d7fabe693d
commit e18534fe12

View File

@ -75,6 +75,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
wrapper += all_substrings['instruct']['bot_turn_stripped'] wrapper += all_substrings['instruct']['bot_turn_stripped']
if impersonate: if impersonate:
wrapper += substrings['user_turn_stripped'].rstrip(' ') wrapper += substrings['user_turn_stripped'].rstrip(' ')
elif _continue:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'])
wrapper += history[-1][1]
else: else:
wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')) wrapper += apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' '))
else: else:
@ -86,7 +89,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"] rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"]
while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length: while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length:
if _continue and i == len(history) - 1: if _continue and i == len(history) - 1:
rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip()) if state['mode'] != 'chat-instruct':
rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip())
else: else:
rows.insert(1, substrings['bot_turn'].replace('<|bot-message|>', history[i][1].strip())) rows.insert(1, substrings['bot_turn'].replace('<|bot-message|>', history[i][1].strip()))