Stop generating in chat mode when \nYou: is generated

oobabooga 2023-01-18 21:51:18 -03:00
parent 022960a087
commit df2e910421

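In chat mode the reply is streamed one step at a time, so once the model starts writing the next user prompt ("\nYou:") there is nothing left worth generating: this commit trims the reply at that marker and breaks out of the generation loop. A minimal standalone sketch of the stopping rule it introduces (fake_stream and its canned strings are illustrative stand-ins for generate_reply, not the project's code):

    # Sketch of the early-stop rule added by this commit. fake_stream()
    # stands in for generate_reply(), which yields ever-longer partial replies.
    name1 = "You"

    def fake_stream():
        yield "Hi"
        yield "Hi there!"
        yield "Hi there!\nYo"
        yield "Hi there!\nYou: so anyway"

    for reply in fake_stream():
        next_character_found = False
        idx = reply.find(f"\n{name1}:")
        if idx != -1:
            reply = reply[:idx]        # drop the model-written user turn
            next_character_found = True
        print(repr(reply))
        if next_character_found:
            break                      # stop generating: the bot's turn is over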

@@ -143,7 +143,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = ".cuda()" if args.cpu else ""
     for i in range(tokens):
         if eos_token is None:
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
         else:
@@ -246,6 +245,7 @@ elif args.chat or args.cai_chat:
         eos_token = '\n' if check else None
         for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             reply = i[0]
+            next_character_found = False
             if check:
                 idx = reply.rfind(question[-1024:])
@@ -256,6 +256,7 @@ elif args.chat or args.cai_chat:
                 idx = reply.find(f"\n{name1}:")
                 if idx != -1:
                     reply = reply[:idx]
+                    next_character_found = True
                 reply = clean_chat_message(reply)
             history[-1] = [text, reply]
@@ -263,14 +264,17 @@ elif args.chat or args.cai_chat:
             # Prevent the chat log from flashing if something like "\nYo" is generated just
             # before "\nYou:" is completed
             tmp = f"\n{name1}:"
-            found = False
-            for j in range(1, len(tmp)):
+            next_character_substring_found = False
+            for j in range(1, len(tmp)+1):
                 if reply[-j:] == tmp[:j]:
-                    found = True
+                    next_character_substring_found = True

-            if not found:
+            if not next_character_substring_found:
                 yield history

+            if next_character_found:
+                break
+
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
             yield generate_chat_html(history, name1, name2)
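Two things change in the last hunk besides the rename of found to next_character_substring_found: the loop bound gains a +1, so the complete "\nYou:" string is compared rather than only its proper prefixes, and a new break exits the generation loop once the next character's prompt has been found and trimmed. A self-contained sketch of the fixed anti-flash check (the helper ends_with_prefix_of is illustrative, not from the repo):

    # Demonstrates the anti-flash check after this commit's off-by-one fix.
    name1 = "You"
    tmp = f"\n{name1}:"

    def ends_with_prefix_of(reply, tmp):
        # True if reply ends with any prefix of "\nYou:", i.e. the model may be
        # midway through (or has just finished) writing the next user turn.
        for j in range(1, len(tmp) + 1):   # the +1 lets j reach len(tmp) itself
            if reply[-j:] == tmp[:j]:
                return True
        return False

    print(ends_with_prefix_of("Hello\nYo", tmp))    # True: would flash, so skip the yield
    print(ends_with_prefix_of("Hello\nYou:", tmp))  # True: only caught thanks to the +1
    print(ends_with_prefix_of("Hello!", tmp))       # False: safe to show in the chat log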