mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Minor generate_chat_prompt simplification
This commit is contained in:
parent
c4aa1a42b1
commit
c3aa79118e
@ -24,6 +24,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
is_instruct = state['mode'] == 'instruct'
|
is_instruct = state['mode'] == 'instruct'
|
||||||
rows = [f"{state['context'].strip()}\n"]
|
rows = [f"{state['context'].strip()}\n"]
|
||||||
|
min_rows = 3
|
||||||
|
|
||||||
# Finding the maximum prompt size
|
# Finding the maximum prompt size
|
||||||
chat_prompt_size = state['chat_prompt_size']
|
chat_prompt_size = state['chat_prompt_size']
|
||||||
@ -40,21 +41,23 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
|
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(shared.history['internal']) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||||
|
|
||||||
if _continue and i == len(shared.history['internal']) - 1:
|
if _continue and i == len(shared.history['internal']) - 1:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||||
else:
|
else:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||||
|
|
||||||
string = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||||
|
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|
||||||
if impersonate:
|
if impersonate:
|
||||||
|
min_rows = 2
|
||||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
limit = 2
|
elif not _continue:
|
||||||
elif _continue:
|
|
||||||
limit = 3
|
|
||||||
else:
|
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
user_input = fix_newlines(user_input)
|
user_input = fix_newlines(user_input)
|
||||||
if len(user_input) > 0:
|
if len(user_input) > 0:
|
||||||
@ -62,9 +65,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
|
|
||||||
# Adding the Character prefix
|
# Adding the Character prefix
|
||||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||||
limit = 3
|
|
||||||
|
|
||||||
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
|
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||||
rows.pop(1)
|
rows.pop(1)
|
||||||
prompt = ''.join(rows)
|
prompt = ''.join(rows)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user