Add a progress bar

This commit is contained in:
oobabooga 2023-01-19 12:20:57 -03:00
parent 5390fc87c8
commit 39bfea5a22

View File

@ -12,6 +12,7 @@ from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
import gc
from tqdm import tqdm
transformers.logging.set_verbosity_error()
@ -175,7 +176,7 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
yield formatted_outputs(question, model_name)
input_ids = encode(question, 1)
preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
for i in range(tokens):
for i in tqdm(range(tokens)):
output = eval(f"model.generate(input_ids, {preset}){cuda}")
reply = decode(output[0])
if eos_token is not None and reply[-1] == eos_token: