Minor consistency fix

This commit is contained in:
oobabooga 2023-03-01 19:11:26 -03:00
parent 7a9b4407b0
commit 955cf431e8

View File

@ -68,10 +68,10 @@ def fix_galactica(s):
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
if shared.model_name.lower().startswith('galactica'): if model_name.lower().startswith('galactica'):
reply = fix_galactica(reply) reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply) return reply, reply, generate_basic_html(reply)
elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
reply = fix_gpt4chan(reply) reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else: else:
@ -87,13 +87,13 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
if shared.is_RWKV: if shared.is_RWKV:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p) reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
yield formatted_outputs(reply, None) yield formatted_outputs(reply, shared.model_name)
else: else:
for i in range(max_new_tokens//8): for i in range(max_new_tokens//8):
reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p) reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
yield formatted_outputs(reply, None) yield formatted_outputs(reply, shared.model_name)
question = reply question = reply
return formatted_outputs(reply, None) return formatted_outputs(reply, shared.model_name)
original_question = question original_question = question
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):