mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 22:08:53 +01:00
Minor bug fix after https://github.com/oobabooga/text-generation-webui/pull/4814
This commit is contained in:
parent
c21a9668a5
commit
6430acadde
@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings):
|
||||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
||||
if shared.is_seq2seq:
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
else:
|
||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||
# Prevent LlamaTokenizer from skipping a space
|
||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
|
||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
||||
reply = ' ' + reply
|
||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
|
||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
||||
reply = ' ' + reply
|
||||
|
||||
return reply
|
||||
|
||||
@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
|
||||
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||
yield get_reply_from_output_ids(output, state, starting_from=starting_from)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
@ -360,7 +357,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
cumulative_reply = ''
|
||||
starting_from = len(input_ids[0])
|
||||
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||
for output in generator:
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user