Fix silero tts autoplay (attempt #2)

This commit is contained in:
oobabooga 2023-05-21 13:24:54 -03:00
parent a5d5bb9390
commit 1e5821bd9e
4 changed files with 25 additions and 9 deletions

View File

@ -41,6 +41,7 @@ script.py may define the special functions and variables below.
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. |
| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. |
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply. | | `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply. |
| `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |

View File

@ -83,17 +83,22 @@ def input_modifier(string):
they are fed into the model. they are fed into the model.
""" """
# Remove autoplay from the last reply
if shared.is_chat() and len(shared.history['internal']) > 0:
shared.history['visible'][-1] = [
shared.history['visible'][-1][0],
shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')
]
shared.processing_message = "*Is recording a voice message...*" shared.processing_message = "*Is recording a voice message...*"
return string return string
def history_modifier(history):
# Remove autoplay from the last reply
if len(history['internal']) > 0:
history['visible'][-1] = [
history['visible'][-1][0],
history['visible'][-1][1].replace('controls autoplay>', 'controls>')
]
return history
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.

View File

@ -181,6 +181,7 @@ def extract_message_from_reply(reply, state):
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
output = copy.deepcopy(history) output = copy.deepcopy(history)
output = apply_extensions('history', output)
if shared.model_name == 'None' or shared.model is None: if shared.model_name == 'None' or shared.model is None:
logging.error("No model is loaded! Select one in the Model tab.") logging.error("No model is loaded! Select one in the Model tab.")
yield output yield output
@ -199,6 +200,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
text = apply_extensions('input', text)
# *Is typing...* # *Is typing...*
if loading_message: if loading_message:
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
@ -307,8 +309,6 @@ def generate_chat_reply(text, history, state, regenerate=False, _continue=False,
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
yield history yield history
return return
else:
text = apply_extensions('input', text)
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
yield history yield history

View File

@ -91,6 +91,15 @@ def _apply_state_modifier_extensions(state):
return state return state
# Extension that modifies the chat history before it is used
def _apply_history_modifier_extensions(history):
for extension, _ in iterator():
if hasattr(extension, "history_modifier"):
history = getattr(extension, "history_modifier")(history)
return history
# Extension functions that override the default tokenizer output - currently only the first one will work # Extension functions that override the default tokenizer output - currently only the first one will work
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator(): for extension, _ in iterator():
@ -165,6 +174,7 @@ EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"), "input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"), "output": partial(_apply_string_extensions, "output_modifier"),
"state": _apply_state_modifier_extensions, "state": _apply_state_modifier_extensions,
"history": _apply_history_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack, "input_hijack": _apply_input_hijack,