Print the used seed

This commit is contained in:
oobabooga 2023-04-10 10:53:31 -03:00
parent 30befe492a
commit 769aa900ea

View File

@ -98,11 +98,13 @@ def formatted_outputs(reply, model_name):
def set_manual_seed(seed): def set_manual_seed(seed):
seed = int(seed)
if seed == -1: if seed == -1:
seed = random.randint(1, 2**31) seed = random.randint(1, 2**31)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
return seed
def stop_everything_event(): def stop_everything_event():
@ -111,7 +113,7 @@ def stop_everything_event():
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(generate_state['seed']) seed = set_manual_seed(generate_state['seed'])
shared.stop_everything = False shared.stop_everything = False
generate_params = {} generate_params = {}
t0 = time.time() t0 = time.time()
@ -153,7 +155,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t1 = time.time() t1 = time.time()
original_tokens = len(encode(original_question)[0]) original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return
input_ids = encode(question, generate_state['max_new_tokens']) input_ids = encode(question, generate_state['max_new_tokens'])
@ -274,5 +276,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t1 = time.time() t1 = time.time()
original_tokens = len(original_input_ids[0]) original_tokens = len(original_input_ids[0])
new_tokens = len(output) - original_tokens new_tokens = len(output) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return