diff --git a/modules/chat.py b/modules/chat.py index 613cae1b..3106d3d2 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -112,6 +112,13 @@ def generate_chat_prompt(user_input, state, **kwargs): if user_input and not impersonate and not _continue: messages.append({"role": "user", "content": user_input}) + def remove_extra_bos(prompt): + for bos_token in ['', '<|startoftext|>']: + while prompt.startswith(bos_token): + prompt = prompt[len(bos_token):] + + return prompt + def make_prompt(messages): if state['mode'] == 'chat-instruct' and _continue: prompt = renderer(messages=messages[:-1]) @@ -123,6 +130,7 @@ def generate_chat_prompt(user_input, state, **kwargs): if state['custom_system_message'].strip() != '': outer_messages.append({"role": "system", "content": state['custom_system_message']}) + prompt = remove_extra_bos(prompt) command = state['chat-instruct_command'] command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1']) command = command.replace('<|prompt|>', prompt) @@ -153,6 +161,7 @@ def generate_chat_prompt(user_input, state, **kwargs): prompt += prefix + prompt = remove_extra_bos(prompt) return prompt prompt = make_prompt(messages)