Instruction templates: better handle unwanted bos tokens

This commit is contained in:
oobabooga 2023-12-17 21:04:03 -08:00
parent 3f3cd4fbe4
commit cac89df97b

View File

@ -112,6 +112,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
if user_input and not impersonate and not _continue: if user_input and not impersonate and not _continue:
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
def remove_extra_bos(prompt):
for bos_token in ['<s>', '<|startoftext|>']:
while prompt.startswith(bos_token):
prompt = prompt[len(bos_token):]
return prompt
def make_prompt(messages): def make_prompt(messages):
if state['mode'] == 'chat-instruct' and _continue: if state['mode'] == 'chat-instruct' and _continue:
prompt = renderer(messages=messages[:-1]) prompt = renderer(messages=messages[:-1])
@ -123,6 +130,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
if state['custom_system_message'].strip() != '': if state['custom_system_message'].strip() != '':
outer_messages.append({"role": "system", "content": state['custom_system_message']}) outer_messages.append({"role": "system", "content": state['custom_system_message']})
prompt = remove_extra_bos(prompt)
command = state['chat-instruct_command'] command = state['chat-instruct_command']
command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1']) command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
command = command.replace('<|prompt|>', prompt) command = command.replace('<|prompt|>', prompt)
@ -153,6 +161,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
prompt += prefix prompt += prefix
prompt = remove_extra_bos(prompt)
return prompt return prompt
prompt = make_prompt(messages) prompt = make_prompt(messages)