Minor generate_chat_prompt simplification

This commit is contained in:
oobabooga 2023-04-14 23:02:08 -03:00
parent c4aa1a42b1
commit c3aa79118e

View File

@ -24,6 +24,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
rows = [f"{state['context'].strip()}\n"] rows = [f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size # Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] chat_prompt_size = state['chat_prompt_size']
@ -40,21 +41,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
i = len(shared.history['internal']) - 1 i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length: while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1: if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else: else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n") rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
string = shared.history['internal'][i][0] string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n") rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
i -= 1 i -= 1
if impersonate: if impersonate:
min_rows = 2
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2 elif not _continue:
elif _continue:
limit = 3
else:
# Adding the user message # Adding the user message
user_input = fix_newlines(user_input) user_input = fix_newlines(user_input)
if len(user_input) > 0: if len(user_input) > 0:
@ -62,9 +65,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Adding the Character prefix # Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length: while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1) rows.pop(1)
prompt = ''.join(rows) prompt = ''.join(rows)