diff --git a/modules/text_generation.py b/modules/text_generation.py index 1324c8b8..cdf4adff 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -68,10 +68,10 @@ def fix_galactica(s): def formatted_outputs(reply, model_name): if not (shared.args.chat or shared.args.cai_chat): - if shared.model_name.lower().startswith('galactica'): + if model_name.lower().startswith('galactica'): reply = fix_galactica(reply) return reply, reply, generate_basic_html(reply) - elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): + elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): reply = fix_gpt4chan(reply) return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) else: @@ -87,13 +87,13 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi if shared.is_RWKV: if shared.args.no_stream: reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p) - yield formatted_outputs(reply, None) + yield formatted_outputs(reply, shared.model_name) else: for i in range(max_new_tokens//8): reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p) - yield formatted_outputs(reply, None) + yield formatted_outputs(reply, shared.model_name) question = reply - return formatted_outputs(reply, None) + return formatted_outputs(reply, shared.model_name) original_question = question if not (shared.args.chat or shared.args.cai_chat):