mirror of https://github.com/oobabooga/text-generation-webui.git
Print the used seed
parent 30befe492a
commit 769aa900ea
@@ -98,11 +98,13 @@ def formatted_outputs(reply, model_name):
 
 def set_manual_seed(seed):
     seed = int(seed)
     if seed == -1:
         seed = random.randint(1, 2**31)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
+    return seed
+
 
 def stop_everything_event():
@@ -111,7 +113,7 @@ def stop_everything_event():
 
 def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
-    set_manual_seed(generate_state['seed'])
+    seed = set_manual_seed(generate_state['seed'])
     shared.stop_everything = False
     generate_params = {}
     t0 = time.time()
@@ -153,7 +155,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
         t1 = time.time()
         original_tokens = len(encode(original_question)[0])
         new_tokens = len(encode(output)[0]) - original_tokens
-        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 
     input_ids = encode(question, generate_state['max_new_tokens'])
@@ -274,5 +276,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
         new_tokens = len(output) - original_tokens
-        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
+        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
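
For reference, a minimal standalone sketch of the seeding helper as it stands after this commit, assuming only the random and torch imports shown; the __main__ demo and its log line are illustrative and not part of the repository.

# Sketch (not repository code): the helper after this commit, plus a
# hypothetical caller that logs the seed the way generate_reply() now does.
import random

import torch


def set_manual_seed(seed):
    # -1 means "pick a seed at random"; any other value is used as given.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 2**31)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # New in this commit: return the seed so the caller can report it.
    return seed


if __name__ == '__main__':
    seed = set_manual_seed(-1)
    print(f'Using seed {seed}')  # illustrative log line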