Revert "Fix stopping strings for llama-3 and phi (#6043)"

This reverts commit 5499bc9bc8.
This commit is contained in:
oobabooga 2024-05-22 17:18:08 -07:00
parent 5499bc9bc8
commit ad54d524f7

View File

@ -45,35 +45,34 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True): def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
''' '''
Given a Jinja template, extracts the prefix and suffix for Given a Jinja template, reverse-engineers the prefix and the suffix for
an assistant message and a user message. It assumes that they an assistant message (if impersonate=False) or an user message
share the same suffix. (if impersonate=True)
''' '''
messages = [ if impersonate:
{"role": "user", "content": "<<|user-message-1|>>"}, messages = [
{"role": "assistant", "content": "<<|assistant-message-1|>>"}, {"role": "user", "content": "<<|user-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"}, {"role": "user", "content": "<<|user-message-2|>>"},
{"role": "assistant", "content": "<<|assistant-message-2|>>"}, ]
] else:
messages = [
{"role": "assistant", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|user-message-2|>>"},
]
prompt = renderer(messages=messages) prompt = renderer(messages=messages)
unwanted_suffix = renderer(messages=[])
suffix = prompt.split('<<|assistant-message-2|>>')[1] suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
if unwanted_suffix != '': suffix = prompt.split("<<|user-message-2|>>")[1]
suffix = suffix[:-len(unwanted_suffix)] prefix = suffix_plus_prefix[len(suffix):]
prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]
if strip_trailing_spaces: if strip_trailing_spaces:
prefix_user = prefix_user.rstrip(' ') prefix = prefix.rstrip(' ')
prefix_assistant = prefix_assistant.rstrip(' ')
return prefix_user, prefix_assistant, suffix return prefix, suffix
def generate_chat_prompt(user_input, state, **kwargs): def generate_chat_prompt(user_input, state, **kwargs):
@ -126,12 +125,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
def remove_extra_bos(prompt): def remove_extra_bos(prompt):
if hasattr(shared.tokenizer, 'bos_token_id'): for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
else:
bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']
for bos_token in bos_tokens:
while prompt.startswith(bos_token): while prompt.startswith(bos_token):
prompt = prompt[len(bos_token):] prompt = prompt[len(bos_token):]
@ -143,9 +137,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
else: else:
prompt = renderer(messages=messages) prompt = renderer(messages=messages)
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
prefix = prefix_user if impersonate else prefix_assistant
if state['mode'] == 'chat-instruct': if state['mode'] == 'chat-instruct':
outer_messages = [] outer_messages = []
if state['custom_system_message'].strip() != '': if state['custom_system_message'].strip() != '':
@ -157,25 +148,29 @@ def generate_chat_prompt(user_input, state, **kwargs):
command = command.replace('<|prompt|>', prompt) command = command.replace('<|prompt|>', prompt)
command = replace_character_names(command, state['name1'], state['name2']) command = replace_character_names(command, state['name1'], state['name2'])
if _continue: if _continue:
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"] prefix += messages[-1]["content"]
elif not impersonate: else:
prefix = apply_extensions('bot_prefix', prefix, state) prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
outer_messages.append({"role": "user", "content": command}) outer_messages.append({"role": "user", "content": command})
outer_messages.append({"role": "assistant", "content": prefix}) outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruction_template.render(messages=outer_messages) prompt = instruction_template.render(messages=outer_messages)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
if len(suffix) > 0: if len(suffix) > 0:
prompt = prompt[:-len(suffix)] prompt = prompt[:-len(suffix)]
else: else:
if _continue: if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0: if len(suffix) > 0:
prompt = prompt[:-len(suffix)] prompt = prompt[:-len(suffix)]
else: else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if state['mode'] == 'chat' and not impersonate: if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state) prefix = apply_extensions('bot_prefix', prefix, state)
@ -254,11 +249,15 @@ def get_stopping_strings(state):
renderers.append(renderer) renderers.append(renderer)
for renderer in renderers: for renderer in renderers:
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer) prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]: stopping_strings += [
stopping_strings.append(item) suffix_user + prefix_bot,
stopping_strings.append(item.rstrip()) suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings') stopping_strings += state.pop('stopping_strings')