From 7c4d5ca8cca25d5e43a4423ac8d69f4583ec933e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 1 Mar 2023 16:40:25 -0300 Subject: [PATCH] Improve the text generation call a bit --- modules/RWKV.py | 2 +- modules/text_generation.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 98b11847..88f1ec23 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -42,4 +42,4 @@ class RWKVModel: token_stop = token_stop ) - return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) + return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) diff --git a/modules/text_generation.py b/modules/text_generation.py index 4c9d1f0e..cc8b62d4 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -86,15 +86,14 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi if shared.is_RWKV: if shared.args.no_stream: - reply = question + shared.model.generate(question, token_count=max_new_tokens, temperature=temperature) + reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature) yield formatted_outputs(reply, None) - return formatted_outputs(reply, None) else: for i in range(max_new_tokens//8): - reply = question + shared.model.generate(question, token_count=8, temperature=temperature) + reply = shared.model.generate(question, token_count=8, temperature=temperature) yield formatted_outputs(reply, None) question = reply - return formatted_outputs(reply, None) + return formatted_outputs(reply, None) original_question = question if not (shared.args.chat or shared.args.cai_chat):