Fix output_ids decoding for Qwen/Qwen-7B-Chat (#5045)

This commit is contained in:
zhangningboo 2023-12-23 10:11:02 +08:00 committed by GitHub
parent dbe438564e
commit 1b8b61b928
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -265,7 +265,14 @@ def apply_stopping_strings(reply, all_stop_strings):
def get_reply_from_output_ids(output_ids, state, starting_from=0): def get_reply_from_output_ids(output_ids, state, starting_from=0):
reply = decode(output_ids[starting_from:], state['skip_special_tokens']) reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from and shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('')) and not reply.startswith(' '):
# Handle tokenizers that do not add the leading space for the first token
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from]))
if isinstance(first_token, (bytes,)):
first_token = first_token.decode('utf8')
if first_token.startswith(''):
reply = ' ' + reply reply = ' ' + reply
return reply return reply