Text Generation: stop if EOS token is reached (#4213)

This commit is contained in:
Brian Dashore 2023-10-07 18:46:42 -04:00 committed by GitHub
parent 7743b5e9de
commit 98fa73a974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -353,10 +353,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
with generate_with_streaming(**generate_params) as generator: with generate_with_streaming(**generate_params) as generator:
for output in generator: for output in generator:
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
if output[-1] in eos_token_ids: if output[-1] in eos_token_ids:
break break
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
finally: finally: