I needed this

This commit is contained in:
oobabooga 2023-01-29 23:20:22 -03:00
parent 3ebca480f6
commit 161cae001b

View File

@ -390,7 +390,15 @@ if args.chat or args.cai_chat:
next_character_found = True
reply = clean_chat_message(reply)
return reply, next_character_found
# Detect if something like "\nYo" is generated just before
# "\nYou:" is completed
tmp = f"\n{other}:"
substring_found = False
for j in range(1, len(tmp)):
if reply[-j:] == tmp[:j]:
substring_found = True
return reply, next_character_found, substring_found
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
original_text = text
@ -400,21 +408,25 @@ if args.chat or args.cai_chat:
history['visible'].append(['', ''])
eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name1}:"):
reply, next_character_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True)
history['internal'][-1] = [text, reply]
history['visible'][-1] = [original_text, apply_extensions(reply, "output")]
yield history['visible']
if not substring_found:
yield history['visible']
if next_character_found:
break
yield history['visible']
def impersonate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True)
eos_token = '\n' if check else None
for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
yield apply_extensions(reply, "output")
reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False)
if not substring_found:
yield apply_extensions(reply, "output")
if next_character_found:
break
yield apply_extensions(reply, "output")
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
for _history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):