Add a progress bar

This commit is contained in:
oobabooga 2023-01-19 12:20:57 -03:00
parent 5390fc87c8
commit 39bfea5a22

View File

@@ -12,6 +12,7 @@ from html_generator import *
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import warnings
 import gc
+from tqdm import tqdm
 transformers.logging.set_verbosity_error()
@@ -175,7 +176,7 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     yield formatted_outputs(question, model_name)
     input_ids = encode(question, 1)
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
-    for i in range(tokens):
+    for i in tqdm(range(tokens)):
         output = eval(f"model.generate(input_ids, {preset}){cuda}")
         reply = decode(output[0])
         if eos_token is not None and reply[-1] == eos_token: