Generate 8 tokens at a time in streaming mode instead of just 1

This is a performance optimization.
This commit is contained in:
oobabooga 2023-01-25 10:38:26 -03:00
parent 651eb50dd1
commit ebed1dea56

View File

@ -204,8 +204,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
 # Generate the reply 1 token at a time
 else:
     yield formatted_outputs(question, model_name)
-    preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
-    for i in tqdm(range(tokens)):
+    preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=8')
+    for i in tqdm(range(tokens//8+1)):
     output = eval(f"model.generate(input_ids, eos_token_id={n}, stopping_criteria=stopping_criteria_list, {preset}){cuda}")
     reply = decode(output[0])
     if eos_token is not None and reply[-1] == eos_token: