From 955cf431e8f6b42662219fca23c507cbad08bf80 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 1 Mar 2023 19:11:26 -0300
Subject: [PATCH] Minor consistency fix

---
 modules/text_generation.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index 1324c8b8..cdf4adff 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -68,10 +68,10 @@ def fix_galactica(s):
 
 def formatted_outputs(reply, model_name):
     if not (shared.args.chat or shared.args.cai_chat):
-        if shared.model_name.lower().startswith('galactica'):
+        if model_name.lower().startswith('galactica'):
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
-        elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
             reply = fix_gpt4chan(reply)
             return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
         else:
@@ -87,13 +87,13 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if shared.is_RWKV:
         if shared.args.no_stream:
             reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
-            yield formatted_outputs(reply, None)
+            yield formatted_outputs(reply, shared.model_name)
         else:
             for i in range(max_new_tokens//8):
                 reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
-                yield formatted_outputs(reply, None)
+                yield formatted_outputs(reply, shared.model_name)
                 question = reply
-            return formatted_outputs(reply, None)
+            return formatted_outputs(reply, shared.model_name)
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
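
Note: a minimal sketch of the contract this patch establishes. formatted_outputs() now dispatches on its model_name parameter instead of reading the global shared.model_name, so the RWKV call sites must pass shared.model_name rather than None. The stub objects below are illustrative assumptions, not the project's real modules.

    # Sketch, not the real code: stubs stand in for shared and the
    # HTML helpers so the call pattern can run on its own.
    class _Args:
        chat = False
        cai_chat = False

    class _Shared:
        args = _Args()
        model_name = 'RWKV-4-Pile-169M'  # hypothetical model name

    shared = _Shared()

    def generate_basic_html(reply):
        return f'<div>{reply}</div>'  # stand-in formatter

    def formatted_outputs(reply, model_name):
        # After the patch: dispatch on the parameter, not the global.
        if not (shared.args.chat or shared.args.cai_chat):
            if model_name.lower().startswith('galactica'):
                return reply, reply, generate_basic_html(reply)
        return reply

    # Call sites now pass the global explicitly, mirroring the patch;
    # passing None would raise AttributeError on None.lower().
    print(formatted_outputs('hello', shared.model_name))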