Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-25 17:29:22 +01:00)
Apply the output extensions only once

Relevant for google translate, silero.

parent 77baf43f6d
commit 3e80f2aceb
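
Before this change, apply_extensions('output', ...) ran on every partial reply while streaming, so heavyweight output extensions (google translate re-translating the growing text, silero re-synthesizing speech) repeated their work on each chunk. After it, the hook runs a single time on the finished reply. A minimal sketch of the difference, with a hypothetical slow_output_extension standing in for a real extension (the names below are illustrative, not the project's API):

    import time

    def slow_output_extension(text):
        # Stand-in for an expensive hook such as a translation request
        # or TTS synthesis; assume each call has a fixed cost.
        time.sleep(0.05)
        return text.upper()

    def stream_tokens():
        for token in ["Hello", " there", ",", " world", "!"]:
            yield token

    # Old behavior: the hook ran once per streamed chunk (O(n) calls).
    reply = ""
    for token in stream_tokens():
        reply += token
        visible = slow_output_extension(reply)

    # New behavior: stream the raw text, post-process the final reply once.
    reply = "".join(stream_tokens())
    visible = slow_output_extension(reply)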
modules/chat.py

@@ -152,6 +152,7 @@ def get_stopping_strings(state):
 def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
     output = copy.deepcopy(history)
     output = apply_extensions('history', output)
+    state = apply_extensions('state', state)
     if shared.model_name == 'None' or shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         yield output
@@ -161,6 +162,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
     just_started = True
     visible_text = None
     stopping_strings = get_stopping_strings(state)
+    is_stream = state['stream']
 
     # Preparing the input
     if not any((regenerate, _continue)):
@@ -204,11 +206,11 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
 
             # Extract the reply
             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
-            visible_reply = apply_extensions("output", visible_reply)
 
             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
             if shared.stop_everything:
+                output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
                 yield output
                 return
 
@@ -221,12 +223,12 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
             if _continue:
                 output['internal'][-1] = [text, last_reply[0] + reply]
                 output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
-                if state['stream']:
+                if is_stream:
                     yield output
             elif not (j == 0 and visible_reply.strip() == ''):
                 output['internal'][-1] = [text, reply.lstrip(' ')]
                 output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
-                if state['stream']:
+                if is_stream:
                     yield output
 
         if reply in [None, cumulative_reply]:
@@ -234,6 +236,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
         else:
             cumulative_reply = reply
 
+    output['visible'][-1][1] = apply_extensions("output", output['visible'][-1][1])
     yield output
 
 
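With the chat.py changes above, the 'output' hook now fires at both places chatbot_wrapper can finish, the early return on shared.stop_everything and the normal end of generation, and each time only on the final visible message. For context, an output extension is a module exposing a modifier function; a minimal illustrative stub (not the real google_translate or silero code) could look like this, assuming apply_extensions('output', text) routes the text through each extension's output_modifier:

    # script.py of a hypothetical extension.
    def output_modifier(string):
        # A real extension would translate the text or synthesize audio
        # here; this stub only tags it so the single invocation is visible.
        return f"[processed once] {string}"

Because the modifier now sees the complete message rather than every partial chunk, a translator gets whole sentences and a TTS extension synthesizes the audio exactly once.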
|
modules/text_generation.py

@@ -103,9 +103,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
         if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
             reply = ' ' + reply
 
-    if not is_chat:
-        reply = apply_extensions('output', reply)
-
     return reply
 
 
@@ -170,7 +167,6 @@ def apply_stopping_strings(reply, all_stop_strings):
 
 
 def _generate_reply(question, state, stopping_strings=None, is_chat=False):
-    state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
         if shared.model_name == 'None' or shared.model is None:
@@ -188,6 +184,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
     # Preparing the input
     original_question = question
     if not is_chat:
+        state = apply_extensions('state', state)
         question = apply_extensions('input', question)
 
     # Finding the stopping strings
@@ -219,6 +216,9 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
         if stop_found:
             break
 
+    if not is_chat:
+        reply = apply_extensions('output', reply)
+
     yield reply
 
 
@@ -311,15 +311,9 @@ def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
 
         if not state['stream']:
             reply = shared.model.generate(question, state)
-            if not is_chat:
-                reply = apply_extensions('output', reply)
-
             yield reply
         else:
             for reply in shared.model.generate_with_streaming(question, state):
-                if not is_chat:
-                    reply = apply_extensions('output', reply)
-
                 yield reply
 
     except Exception:
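
The text_generation.py side consolidates the same idea for non-chat generation: rather than each backend generator (generate_reply_custom above, and its siblings) applying the hook per chunk, _generate_reply applies it once after the streaming loop, and skips it in chat mode since chatbot_wrapper now handles that. A simplified sketch of the resulting control flow, where backend_stream and output_hook are stand-ins for the model backend and the extension hook:

    def backend_stream():
        # Stands in for shared.model.generate_with_streaming(question, state).
        for partial in ["He", "Hell", "Hello"]:
            yield partial

    def generate_reply_sketch(is_chat=False, output_hook=str.upper):
        reply = ""
        for reply in backend_stream():
            yield reply                 # raw partial replies while streaming

        if not is_chat:
            reply = output_hook(reply)  # applied once, to the final text only

        yield reply

    print(list(generate_reply_sketch()))  # ['He', 'Hell', 'Hello', 'HELLO']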