Properly count tokens/s for llama.cpp in chat mode

This commit is contained in:
oobabooga 2023-03-31 17:00:55 -03:00
parent 5c4e44b452
commit 0aee7341d8

View File

@ -120,6 +120,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
output = original_question+reply
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@ -130,6 +131,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
output = original_question+reply
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@ -138,9 +140,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
traceback.print_exc()
finally:
t1 = time.time()
output = encode(reply)[0]
input_ids = encode(question)
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
input_ids = encode(question, max_new_tokens)
@ -272,5 +274,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
traceback.print_exc()
finally:
t1 = time.time()
print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
original_tokens = len(original_input_ids[0])
new_tokens = len(output)-original_tokens
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return