mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Fix generation_attempts continuing after an empty reply
This commit is contained in:
parent
e18534fe12
commit
fb91406e93
@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
history = kwargs.get('history', shared.history)['internal']
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
|
||||
# Finding the maximum prompt size
|
||||
# FInd the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
@ -66,7 +66,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
|
||||
substrings = all_substrings['instruct' if is_instruct else 'chat']
|
||||
|
||||
# Creating the template for "chat-instruct" mode
|
||||
# Create the template for "chat-instruct" mode
|
||||
if state['mode'] == 'chat-instruct':
|
||||
wrapper = ''
|
||||
command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
|
||||
@ -83,7 +83,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
else:
|
||||
wrapper = '<|prompt|>'
|
||||
|
||||
# Building the prompt
|
||||
# Build the prompt
|
||||
min_rows = 3
|
||||
i = len(history) - 1
|
||||
rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"]
|
||||
@ -107,11 +107,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
min_rows = 2
|
||||
rows.append(substrings['user_turn_stripped'].rstrip(' '))
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
# Add the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
|
||||
|
||||
# Adding the Character prefix
|
||||
# Add the character prefix
|
||||
if state['mode'] != 'chat-instruct':
|
||||
rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')))
|
||||
|
||||
@ -192,7 +192,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
||||
return
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
@ -232,12 +231,13 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
# Extract the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
@ -268,7 +268,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
||||
if next_character_found:
|
||||
break
|
||||
|
||||
if reply in [None, '']:
|
||||
if reply in [None, cumulative_reply]:
|
||||
break
|
||||
else:
|
||||
cumulative_reply = reply
|
||||
|
Loading…
Reference in New Issue
Block a user