mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Minor bug fix after https://github.com/oobabooga/text-generation-webui/pull/4814
This commit is contained in:
parent
c21a9668a5
commit
6430acadde
@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||||||
|
|
||||||
|
|
||||||
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
||||||
if shared.is_seq2seq:
|
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||||
reply = decode(output_ids, state['skip_special_tokens'])
|
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
|
||||||
else:
|
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
||||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
reply = ' ' + reply
|
||||||
# Prevent LlamaTokenizer from skipping a space
|
|
||||||
if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
|
|
||||||
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
|
|
||||||
reply = ' ' + reply
|
|
||||||
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
if cuda:
|
if cuda:
|
||||||
output = output.cuda()
|
output = output.cuda()
|
||||||
|
|
||||||
yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
|
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||||
|
yield get_reply_from_output_ids(output, state, starting_from=starting_from)
|
||||||
|
|
||||||
# Stream the reply 1 token at a time.
|
# Stream the reply 1 token at a time.
|
||||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||||
@ -360,7 +357,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||||||
|
|
||||||
with generate_with_streaming(**generate_params) as generator:
|
with generate_with_streaming(**generate_params) as generator:
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
starting_from = len(input_ids[0])
|
starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
|
||||||
for output in generator:
|
for output in generator:
|
||||||
if output[-1] in eos_token_ids:
|
if output[-1] in eos_token_ids:
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user