Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2024-12-25 05:48:55 +01:00
Fix generation_attempts continuing after an empty reply
This commit is contained in:
parent e18534fe12
commit fb91406e93
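This commit fixes the retry loop in `chatbot_wrapper` (modules/chat.py). The wrapper runs up to `state['chat_generation_attempts']` generation passes, feeding the text accumulated so far back into the model and extending it via `reply = cumulative_reply + reply`. Because of that accumulation, the old early-exit test `if reply in [None, '']` could only fire while nothing at all had been generated: once any attempt produced text, `reply` was never the empty string, so the loop kept spending attempts even when a pass contributed nothing new. The fix compares `reply` against `cumulative_reply` instead, which detects a no-progress attempt, and moves the accumulator's reset next to the loop that uses it. The first four hunks, all in `generate_chat_prompt`, are incidental comment rewording (gerund to imperative):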
```diff
@@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     history = kwargs.get('history', shared.history)['internal']
     is_instruct = state['mode'] == 'instruct'
 
-    # Finding the maximum prompt size
+    # Find the maximum prompt size
     chat_prompt_size = state['chat_prompt_size']
     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
```
```diff
@@ -66,7 +66,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
 
     substrings = all_substrings['instruct' if is_instruct else 'chat']
 
-    # Creating the template for "chat-instruct" mode
+    # Create the template for "chat-instruct" mode
     if state['mode'] == 'chat-instruct':
         wrapper = ''
         command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
@@ -83,7 +83,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     else:
         wrapper = '<|prompt|>'
 
-    # Building the prompt
+    # Build the prompt
     min_rows = 3
     i = len(history) - 1
     rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"]
@@ -107,11 +107,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
         min_rows = 2
         rows.append(substrings['user_turn_stripped'].rstrip(' '))
     elif not _continue:
-        # Adding the user message
+        # Add the user message
         if len(user_input) > 0:
             rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
 
-        # Adding the Character prefix
+        # Add the character prefix
         if state['mode'] != 'chat-instruct':
             rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')))
 
```
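The remaining hunks, in `chatbot_wrapper`, carry the actual fix. First, the accumulator initialization is dropped from the top-level variable block: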
```diff
@@ -192,7 +192,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
         return
 
     # Defining some variables
-    cumulative_reply = ''
     just_started = True
     visible_text = None
     eos_token = '\n' if state['stop_at_newline'] else None
```
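It is re-added just before the generation loop in the next hunk. Within a single call the move is organizational rather than behavioral, but it keeps the reset directly above the loop that consumes it: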
```diff
@@ -232,12 +231,13 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
         prompt = generate_chat_prompt(text, state, **kwargs)
 
     # Generate
+    cumulative_reply = ''
     for i in range(state['chat_generation_attempts']):
         reply = None
         for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
             reply = cumulative_reply + reply
 
-            # Extracting the reply
+            # Extract the reply
             reply, next_character_found = extract_message_from_reply(reply, state)
             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
             visible_reply = apply_extensions("output", visible_reply)
```
||||||
@ -268,7 +268,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
|
|
||||||
if reply in [None, '']:
|
if reply in [None, cumulative_reply]:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
cumulative_reply = reply
|
cumulative_reply = reply
|
||||||
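Putting the pieces together, here is a minimal, runnable sketch of the fixed loop. `fake_generate_reply` is a hypothetical stub standing in for the project's streaming `generate_reply`, and the literal `5` stands in for `state['chat_generation_attempts']`:

```python
def fake_generate_reply(prompt):
    # Hypothetical stub: the first attempt streams some text,
    # every later attempt streams nothing new.
    yield 'Hello' if 'Hello' not in prompt else ''

cumulative_reply = ''
attempts_used = 0
for i in range(5):
    attempts_used += 1
    reply = None
    for reply in fake_generate_reply('prompt ' + cumulative_reply):
        reply = cumulative_reply + reply  # accumulate, as in chatbot_wrapper

    # With the old `reply in [None, '']` test, reply == 'Hello' here even for
    # an empty attempt, so all 5 attempts would run. The fixed test stops as
    # soon as an attempt fails to extend cumulative_reply.
    if reply in [None, cumulative_reply]:
        break
    else:
        cumulative_reply = reply

print(attempts_used, repr(cumulative_reply))  # 2 'Hello'
```

With the fix, an attempt that produces nothing (`reply` is still `None`) or merely reproduces what earlier attempts already generated ends the loop instead of burning the remaining attempts.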