Apply the output extensions only once

Relevant for google translate, silero
This commit is contained in:
oobabooga 2023-06-24 10:59:07 -03:00
parent 77baf43f6d
commit 3e80f2aceb
2 changed files with 10 additions and 13 deletions

View File

@ -152,6 +152,7 @@ def get_stopping_strings(state):
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
output = copy.deepcopy(history) output = copy.deepcopy(history)
output = apply_extensions('history', output) output = apply_extensions('history', output)
state = apply_extensions('state', state)
if shared.model_name == 'None' or shared.model is None: if shared.model_name == 'None' or shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.") logger.error("No model is loaded! Select one in the Model tab.")
yield output yield output
@ -161,6 +162,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
just_started = True just_started = True
visible_text = None visible_text = None
stopping_strings = get_stopping_strings(state) stopping_strings = get_stopping_strings(state)
is_stream = state['stream']
# Preparing the input # Preparing the input
if not any((regenerate, _continue)): if not any((regenerate, _continue)):
@ -204,11 +206,11 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
# Extract the reply # Extract the reply
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions("output", visible_reply)
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
# otherwise gradio gets confused # otherwise gradio gets confused
if shared.stop_everything: if shared.stop_everything:
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
yield output yield output
return return
@ -221,12 +223,12 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if _continue: if _continue:
output['internal'][-1] = [text, last_reply[0] + reply] output['internal'][-1] = [text, last_reply[0] + reply]
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
if state['stream']: if is_stream:
yield output yield output
elif not (j == 0 and visible_reply.strip() == ''): elif not (j == 0 and visible_reply.strip() == ''):
output['internal'][-1] = [text, reply.lstrip(' ')] output['internal'][-1] = [text, reply.lstrip(' ')]
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
if state['stream']: if is_stream:
yield output yield output
if reply in [None, cumulative_reply]: if reply in [None, cumulative_reply]:
@ -234,6 +236,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
else: else:
cumulative_reply = reply cumulative_reply = reply
output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
yield output yield output

View File

@ -103,9 +103,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith(''): if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith(''):
reply = ' ' + reply reply = ' ' + reply
if not is_chat:
reply = apply_extensions('output', reply)
return reply return reply
@ -170,7 +167,6 @@ def apply_stopping_strings(reply, all_stop_strings):
def _generate_reply(question, state, stopping_strings=None, is_chat=False): def _generate_reply(question, state, stopping_strings=None, is_chat=False):
state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply') generate_func = apply_extensions('custom_generate_reply')
if generate_func is None: if generate_func is None:
if shared.model_name == 'None' or shared.model is None: if shared.model_name == 'None' or shared.model is None:
@ -188,6 +184,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
# Preparing the input # Preparing the input
original_question = question original_question = question
if not is_chat: if not is_chat:
state = apply_extensions('state', state)
question = apply_extensions('input', question) question = apply_extensions('input', question)
# Finding the stopping strings # Finding the stopping strings
@ -219,6 +216,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
if stop_found: if stop_found:
break break
if not is_chat:
reply = apply_extensions('output', reply)
yield reply yield reply
@ -311,15 +311,9 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
if not state['stream']: if not state['stream']:
reply = shared.model.generate(question, state) reply = shared.model.generate(question, state)
if not is_chat:
reply = apply_extensions('output', reply)
yield reply yield reply
else: else:
for reply in shared.model.generate_with_streaming(question, state): for reply in shared.model.generate_with_streaming(question, state):
if not is_chat:
reply = apply_extensions('output', reply)
yield reply yield reply
except Exception: except Exception: