diff --git a/modules/text_generation.py b/modules/text_generation.py
index b8b2f496..f4cc25d4 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -120,6 +120,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+                output = original_question+reply
                 if not (shared.args.chat or shared.args.cai_chat):
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
@@ -130,6 +131,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    output = original_question+reply
                     if not (shared.args.chat or shared.args.cai_chat):
                         reply = original_question + apply_extensions(reply, "output")
                     yield formatted_outputs(reply, shared.model_name)
@@ -138,9 +140,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             traceback.print_exc()
         finally:
             t1 = time.time()
-            output = encode(reply)[0]
-            input_ids = encode(question)
-            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            original_tokens = len(encode(original_question)[0])
+            new_tokens = len(encode(output)[0]) - original_tokens
+            print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
             return

     input_ids = encode(question, max_new_tokens)
@@ -272,5 +274,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         traceback.print_exc()
     finally:
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
+        original_tokens = len(original_input_ids[0])
+        new_tokens = len(output)-original_tokens
+        print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
         return
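
Note on the counting logic: after this patch, both `finally:` blocks derive the printed stats the same way, by encoding the prompt, encoding the full output (prompt plus reply), and taking the difference as the number of generated tokens, so tokens/s reflects only new tokens. Below is a minimal, self-contained sketch of that arithmetic; `encode()` here is a whitespace-splitting stub standing in for the repo's real tokenizer (which returns a batch, hence the `[0]` in the patch), and the values are illustrative, not a claim about actual throughput.

import time

# Stand-in tokenizer (assumption): one "token" per whitespace-separated word.
# The real encode() in modules/text_generation.py wraps the model tokenizer.
def encode(text):
    return text.split()

def report(original_question, output, t0, t1):
    # Same arithmetic as the patched finally: blocks: context size is the
    # encoded prompt; new tokens are the encoded full output minus the prompt.
    original_tokens = len(encode(original_question))
    new_tokens = len(encode(output)) - original_tokens
    print(f"Output generated in {(t1-t0):.2f} seconds "
          f"({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")

# Illustrative usage with a fabricated half-second generation time:
t0 = time.time()
original_question = "What is RWKV?"
output = original_question + " RWKV is an RNN that matches transformer quality."
t1 = t0 + 0.5
report(original_question, output, t0, t1)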