From 1e5821bd9e90e11ef5d0cba487465410d430fe85 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 21 May 2023 13:24:54 -0300 Subject: [PATCH] Fix silero tts autoplay (attempt #2) --- docs/Extensions.md | 1 + extensions/silero_tts/script.py | 19 ++++++++++++------- modules/chat.py | 4 ++-- modules/extensions.py | 10 ++++++++++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/docs/Extensions.md b/docs/Extensions.md index 0e396ce2..b80e1a26 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -41,6 +41,7 @@ script.py may define the special functions and variables below. | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | +| `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply. | | `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index e3b5b1aa..c54d9ac5 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -83,17 +83,22 @@ def input_modifier(string): they are fed into the model. """ - # Remove autoplay from the last reply - if shared.is_chat() and len(shared.history['internal']) > 0: - shared.history['visible'][-1] = [ - shared.history['visible'][-1][0], - shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>') - ] - shared.processing_message = "*Is recording a voice message...*" return string +def history_modifier(history): + + # Remove autoplay from the last reply + if len(history['internal']) > 0: + history['visible'][-1] = [ + history['visible'][-1][0], + history['visible'][-1][1].replace('controls autoplay>', 'controls>') + ] + + return history + + def output_modifier(string): """ This function is applied to the model outputs. diff --git a/modules/chat.py b/modules/chat.py index ea52145d..5565b6c8 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -181,6 +181,7 @@ def extract_message_from_reply(reply, state): def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True): output = copy.deepcopy(history) + output = apply_extensions('history', output) if shared.model_name == 'None' or shared.model is None: logging.error("No model is loaded! Select one in the Model tab.") yield output @@ -199,6 +200,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa if visible_text is None: visible_text = text + text = apply_extensions('input', text) # *Is typing...* if loading_message: yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} @@ -307,8 +309,6 @@ def generate_chat_reply(text, history, state, regenerate=False, _continue=False, if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0: yield history return - else: - text = apply_extensions('input', text) for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): yield history diff --git a/modules/extensions.py b/modules/extensions.py index 2f68caf5..d41ae3df 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -91,6 +91,15 @@ def _apply_state_modifier_extensions(state): return state +# Extension that modifies the chat history before it is used +def _apply_history_modifier_extensions(history): + for extension, _ in iterator(): + if hasattr(extension, "history_modifier"): + history = getattr(extension, "history_modifier")(history) + + return history + + # Extension functions that override the default tokenizer output - currently only the first one will work def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): for extension, _ in iterator(): @@ -165,6 +174,7 @@ EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), "state": _apply_state_modifier_extensions, + "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), "input_hijack": _apply_input_hijack,