diff --git a/modules/text_generation.py b/modules/text_generation.py index acae1007..03f177b9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -1,3 +1,4 @@ +import random import re import time import traceback @@ -97,10 +98,11 @@ def formatted_outputs(reply, model_name): def set_manual_seed(seed): - if seed != -1: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + if seed == -1: + seed = random.randint(1, 2**31) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) def stop_everything_event():