oobabooga 2023-12-05 10:05:54 -08:00
parent c21a9668a5
commit 6430acadde

modules/text_generation.py

@@ -264,12 +264,8 @@ def apply_stopping_strings(reply, all_stop_strings):
 def get_reply_from_output_ids(output_ids, state, starting_from=0):
-    if shared.is_seq2seq:
-        reply = decode(output_ids, state['skip_special_tokens'])
-    else:
-        reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
-
-        # Prevent LlamaTokenizer from skipping a space
-        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
-            if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
-                reply = ' ' + reply
+    reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
+    # Prevent LlamaTokenizer from skipping a space
+    if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
+        if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
+            reply = ' ' + reply
 
     return reply
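
Note: two things change in this hunk. The seq2seq special case moves out of get_reply_from_output_ids and into the callers (the two hunks below compute starting_from at the call sites), and the guard len(output_ids) > 0 becomes len(output_ids) > starting_from, which protects the index actually used on the next line and avoids an IndexError when generation produced no tokens past the prompt. The space fix itself exists because SentencePiece-based Llama tokenizers mark word-initial tokens with '▁' (U+2581), and decoding a slice that starts at such a token drops the space it encodes. A minimal sketch of the behavior, using a hypothetical ToyTokenizer rather than the real transformers classes:

SENTENCEPIECE_SPACE = '\u2581'  # the '▁' marker tested with startswith() above

class ToyTokenizer:
    # Stand-in for LlamaTokenizer: ids map to SentencePiece-style pieces.
    vocab = {0: '▁Hello', 1: '▁world', 2: '!'}

    def convert_ids_to_tokens(self, i):
        return self.vocab[i]

    def decode(self, ids):
        # Mapping '▁' to ' ' and stripping loses the leading space when
        # the slice starts with a word-initial piece.
        return ''.join(self.vocab[i] for i in ids).replace('▁', ' ').lstrip()

tok = ToyTokenizer()
output_ids = [0, 1, 2]   # full sequence; the prompt was token 0
starting_from = 1        # skip the echoed prompt

reply = tok.decode(output_ids[starting_from:])  # 'world!' -- space lost
if len(output_ids) > starting_from:             # the corrected bounds check
    if tok.convert_ids_to_tokens(output_ids[starting_from]).startswith(SENTENCEPIECE_SPACE):
        reply = ' ' + reply                     # ' world!' -- space restored
print(repr(reply))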
@@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
                 if cuda:
                     output = output.cuda()
 
-            yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
+            starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
+            yield get_reply_from_output_ids(output, state, starting_from=starting_from)
 
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator.
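
The non-streaming path above and the streaming path below get the same treatment: starting_from is now computed at the call site, with 0 for seq2seq models. A short sketch of why follows the last hunk.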
@@ -360,7 +357,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
 
             with generate_with_streaming(**generate_params) as generator:
                 cumulative_reply = ''
-                starting_from = len(input_ids[0])
+                starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
                 for output in generator:
                     if output[-1] in eos_token_ids:
                         break
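
The reason for the 0-vs-len(input_ids[0]) split: decoder-only models echo the prompt at the start of generate()'s output, so the reply begins at len(input_ids[0]), while encoder-decoder (seq2seq) models return only newly generated tokens, so the reply begins at index 0. A sketch with hypothetical stand-in data (fake_generate is not part of the codebase):

prompt_ids = [101, 102, 103]
new_ids = [7, 8, 9]

def fake_generate(is_seq2seq):
    # Hypothetical stand-in for shared.model.generate(...)[0]:
    # seq2seq output holds only the new tokens, causal output echoes the prompt.
    return list(new_ids) if is_seq2seq else prompt_ids + new_ids

for is_seq2seq in (False, True):
    output_ids = fake_generate(is_seq2seq)
    starting_from = 0 if is_seq2seq else len(prompt_ids)
    assert output_ids[starting_from:] == new_ids  # same reply tokens either way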