From 30befe492a18bd34c27534d852ac2df2054f8f8f Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Mon, 10 Apr 2023 06:29:10 -0700 Subject: [PATCH 1/2] fix random seeds to actually randomize Without this fix, manual seeds get locked in. --- modules/text_generation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index acae1007..03f177b9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -1,3 +1,4 @@ +import random import re import time import traceback @@ -97,10 +98,11 @@ def formatted_outputs(reply, model_name): def set_manual_seed(seed): - if seed != -1: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + if seed == -1: + seed = random.randint(1, 2**31) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) def stop_everything_event(): From 769aa900ea1f45bbababa5d07a1d120af9e1c2ad Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 10 Apr 2023 10:53:31 -0300 Subject: [PATCH 2/2] Print the used seed --- modules/text_generation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 03f177b9..8846eaff 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -98,11 +98,13 @@ def formatted_outputs(reply, model_name): def set_manual_seed(seed): + seed = int(seed) if seed == -1: seed = random.randint(1, 2**31) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + return seed def stop_everything_event(): @@ -111,7 +113,7 @@ def stop_everything_event(): def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): clear_torch_cache() - set_manual_seed(generate_state['seed']) + seed = set_manual_seed(generate_state['seed']) shared.stop_everything = False generate_params = {} t0 = time.time() @@ -153,7 +155,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] t1 = time.time() original_tokens = len(encode(original_question)[0]) new_tokens = len(encode(output)[0]) - original_tokens - print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') + print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return input_ids = encode(question, generate_state['max_new_tokens']) @@ -274,5 +276,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] t1 = time.time() original_tokens = len(original_input_ids[0]) new_tokens = len(output) - original_tokens - print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') + print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return