From 161cae001b1ba1562a7c6f537e16c9916bc06e2e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 29 Jan 2023 23:20:22 -0300 Subject: [PATCH] I needed this --- server.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 140cdc4e..65810370 100644 --- a/server.py +++ b/server.py @@ -390,7 +390,15 @@ if args.chat or args.cai_chat: next_character_found = True reply = clean_chat_message(reply) - return reply, next_character_found + # Detect if something like "\nYo" is generated just before + # "\nYou:" is completed + tmp = f"\n{other}:" + substring_found = False + for j in range(1, len(tmp)): + if reply[-j:] == tmp[:j]: + substring_found = True + + return reply, next_character_found, substring_found def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): original_text = text @@ -400,21 +408,25 @@ if args.chat or args.cai_chat: history['visible'].append(['', '']) eos_token = '\n' if check else None for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"): - reply, next_character_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True) + reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True) history['internal'][-1] = [text, reply] history['visible'][-1] = [original_text, apply_extensions(reply, "output")] - yield history['visible'] + if not substring_found: + yield history['visible'] if next_character_found: break + yield history['visible'] def impersonate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True) eos_token = '\n' if check else None for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name2}:"): - reply, next_character_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) - yield apply_extensions(reply, "output") + reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) + if not substring_found: + yield apply_extensions(reply, "output") if next_character_found: break + yield apply_extensions(reply, "output") def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):