From df2e910421d547bb226dd9089d51ceaaaf00e7c4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 18 Jan 2023 21:51:18 -0300
Subject: [PATCH] Stop generating in chat mode when \nYou: is generated

---
 server.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/server.py b/server.py
index 79fe6628..0cbbb0b2 100644
--- a/server.py
+++ b/server.py
@@ -143,7 +143,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = ".cuda()" if args.cpu else ""
     for i in range(tokens):
-
         if eos_token is None:
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
         else:
@@ -246,6 +245,7 @@ elif args.chat or args.cai_chat:
         eos_token = '\n' if check else None
         for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             reply = i[0]
+            next_character_found = False
 
             if check:
                 idx = reply.rfind(question[-1024:])
@@ -256,6 +256,7 @@ elif args.chat or args.cai_chat:
                 idx = reply.find(f"\n{name1}:")
                 if idx != -1:
                     reply = reply[:idx]
+                    next_character_found = True
                 reply = clean_chat_message(reply)
 
             history[-1] = [text, reply]
@@ -263,14 +264,17 @@ elif args.chat or args.cai_chat:
             # Prevent the chat log from flashing if something like "\nYo" is generated just
             # before "\nYou:" is completed
             tmp = f"\n{name1}:"
-            found = False
-            for j in range(1, len(tmp)):
+            next_character_substring_found = False
+            for j in range(1, len(tmp)+1):
                 if reply[-j:] == tmp[:j]:
-                    found = True
+                    next_character_substring_found = True
 
-            if not found:
+            if not next_character_substring_found:
                 yield history
 
+            if next_character_found:
+                break
+
     def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
             yield generate_chat_html(history, name1, name2)
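
Note on the flashing-prevention loop (not part of the patch): it holds a UI update back whenever the reply currently ends with some prefix of the stop string tmp = f"\n{name1}:", so fragments like "\nYo" are never shown mid-stream. Changing range(1, len(tmp)) to range(1, len(tmp)+1) fixes an off-by-one: the old bound only tested proper prefixes and never the complete stop string. Below is a minimal standalone Python sketch of that check; the helper name ends_with_prefix_of is hypothetical and does not appear in server.py.

    def ends_with_prefix_of(reply, stop_string):
        # True if reply ends with any non-empty prefix of stop_string,
        # including stop_string itself (that is what the +1 bound buys).
        return any(reply.endswith(stop_string[:j])
                   for j in range(1, len(stop_string) + 1))

    name1 = "You"
    stop = f"\n{name1}:"
    print(ends_with_prefix_of("Hello\nYo", stop))    # True: hold this frame back
    print(ends_with_prefix_of("Hello\nYou:", stop))  # True: full match, caught only with the +1
    print(ends_with_prefix_of("Hello there", stop))  # False: safe to yield

For every j in this range, reply.endswith(stop_string[:j]) is equivalent to the patch's reply[-j:] == tmp[:j] comparison, including when reply is shorter than j.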