Add also_return_rows to generate_chat_prompt

This commit is contained in:
oobabooga 2023-04-01 01:12:13 -03:00
parent 8c51b405e4
commit fcda3f8776

View File

@ -22,7 +22,7 @@ def generate_chat_output(history, name1, name2, character):
else: else:
return history return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
user_input = fix_newlines(user_input) user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
@ -51,6 +51,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
rows.pop(1) rows.pop(1)
prompt = ''.join(rows) prompt = ''.join(rows)
if also_return_rows:
return prompt, rows
else:
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, check): def extract_message_from_reply(reply, name1, name2, check):