From ef8637e32da3447d22c321c0fcf326f291939a11 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:49:56 -0300 Subject: [PATCH] Add extension example, replace input_hijack with chat_input_modifier (#3307) --- docs/Extensions.md | 223 +++++++++++++++++++---------- extensions/api/blocking_api.py | 13 +- extensions/api/streaming_api.py | 9 +- extensions/api/util.py | 1 - extensions/example/script.py | 129 +++++++++++++++++ extensions/multimodal/script.py | 9 ++ extensions/send_pictures/script.py | 16 ++- extensions/whisper_stt/script.py | 10 ++ modules/chat.py | 2 +- modules/extensions.py | 23 ++- 10 files changed, 335 insertions(+), 100 deletions(-) create mode 100644 extensions/example/script.py diff --git a/docs/Extensions.md b/docs/Extensions.md index e156456b..b0c88188 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -1,45 +1,45 @@ -Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if specified with the `--extensions` flag. +# Extensions + +Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are loaded at startup if the folder name is specified after the `--extensions` flag. For instance, `extensions/silero_tts/script.py` gets loaded with `python server.py --extensions silero_tts`. ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) -The link above contains a directory of user extensions for text-generation-webui. +The repository above contains a directory of user extensions. -If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above. +If you create an extension, you are welcome to host it in a GitHub repository and submit a PR adding it to the list above. ## Built-in extensions -Most of these have been created by the extremely talented contributors that you can find here: [contributors](https://github.com/oobabooga/text-generation-webui/graphs/contributors?from=2022-12-18&to=&type=a). - |Extension|Description| |---------|-----------| -|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for this web UI. | +|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for the webui. | +|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | +|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. | |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.| -|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.| -|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. | -|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. | +|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. | |[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. | -|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | -|[multimodal](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal) | Adds multimodality support (text+images). For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/multimodal/README.md) in the extension directory. | -|[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | +|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. | +|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | +|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. | |[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. | ## How to write an extension -script.py may define the special functions and variables below. - -#### Predefined functions +The extensions framework is based on special functions and variables that you can define in `script.py`. The functions are the following: | Function | Description | |-------------|-------------| +| `def setup()` | Is executed when the extension gets imported. | | `def ui()` | Creates custom gradio elements when the UI is launched. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_js()` | Same as above but for javascript. | | `def input_modifier(string, state)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string, state)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | +| `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | @@ -48,9 +48,7 @@ script.py may define the special functions and variables below. | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See the `multimodal` extension for an example. | | `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `multimodal` extension for an example. | -#### `params` dictionary - -In this dictionary, `display_name` is used to define the displayed name of the extension in the UI, and `is_tab` is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. +Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. Example: @@ -61,7 +59,7 @@ params = { } ``` -Additionally, `params` may contain variables that you want to be customizable through a `settings.json` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in +Additionally, `params` may contain variables that you want to be customizable through a `settings.yaml` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in ```python params = { @@ -71,32 +69,19 @@ params = { } ``` -can be customized by adding a key called `google_translate-language string` to `settings.json`: +can be customized by adding a key called `google_translate-language string` to `settings.yaml`: ```python -"google_translate-language string": "fr", +google_translate-language string: 'fr' ``` -That is, the syntax is `extension_name-variable_name`. - -#### `input_hijack` dictionary - -```python -input_hijack = { - 'state': False, - 'value': ["", ""] -} -``` -This is only used in chat mode. If your extension sets `input_hijack['state'] = True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example. - -Additionally, your extension can set the value to be a callback in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `multimodal` extension above for an example. +That is, the syntax for the key is `extension_name-variable_name`. ## Using multiple extensions at the same time -In order to use your extension, you must start the web UI with the `--extensions` flag followed by the name of your extension (the folder under `text-generation-webui/extension` where `script.py` resides). - -You can activate more than one extension at a time by providing their names separated by spaces. The input, output, and bot prefix modifiers will be applied in the specified order. +You can activate more than one extension at a time by providing their names separated by spaces after `--extensions`. The input, output, and bot prefix modifiers will be applied in the specified order. +Example: ``` python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate @@ -106,56 +91,142 @@ python server.py --extensions translate enthusiasm # First apply translate, then Do note, that for: - `custom_generate_chat_prompt` - `custom_generate_reply` -- `tokenizer_modifier` - `custom_tokenized_length` only the first declaration encountered will be used and the rest will be ignored. -## The `bot_prefix_modifier` +## A full example -In chat mode, this function modifies the prefix for a new bot message. For instance, if your bot is named `Marie Antoinette`, the default prefix for a new message will be - -``` -Marie Antoinette: -``` - -Using `bot_prefix_modifier`, you can change it to: - -``` -Marie Antoinette: *I am very enthusiastic* -``` - -Marie Antoinette will become very enthusiastic in all her messages. - -## `custom_generate_reply` example - -Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet. - -Note that in chat mode, this function must only return the new text, whereas in other modes it must return the original prompt + the new text. +The source code below can be found at [extensions/example/script.py](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/example/script.py). ```python -import datetime +""" +An example of extension. It does nothing, but you can add transformations +before the return statements to customize the webui behavior. -def custom_generate_reply(question, original_question, seed, state, stopping_strings): - cumulative = '' - for i in range(10): - cumulative += f"Counting: {i}...\n" - yield cumulative +Starting from history_modifier and ending in output_modifier, the +functions are declared in the same order that they are called at +generation time. +""" - cumulative += f"Done! {str(datetime.datetime.now())}" - yield cumulative -``` - -## `custom_generate_chat_prompt` example - -Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode. - -```python +import torch from modules import chat +from modules.text_generation import ( + decode, + encode, + generate_reply, +) +from transformers import LogitsProcessor + +params = { + "display_name": "Example Extension", + "is_tab": False, +} + +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + It gets used in the custom_logits_processor function below. + """ + def __init__(self): + pass + + def __call__(self, input_ids, scores): + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the internal and visible input strings in chat mode. + """ + return text, visible_text + +def input_modifier(string, state): + """ + In chat mode, modifies the user input. The modified version goes into + history['internal'], and the original version goes into history['visible']. + + In default/notebook modes, modifies the whole prompt. + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['internal'], and the original version goes into history['visible']. + """ + return string def custom_generate_chat_prompt(user_input, state, **kwargs): - - # Do something with kwargs['history'] or state + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result - return chat.generate_chat_prompt(user_input, state, **kwargs) +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and their corresponding + event handlers should be defined here. + """ + pass ``` diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index edc6d8f4..fbbc5ec1 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -7,10 +7,15 @@ from modules import shared from modules.chat import generate_chat_reply from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model -from modules.models_settings import (get_model_settings_from_yamls, - update_model_parameters) -from modules.text_generation import (encode, generate_reply, - stop_everything_event) +from modules.models_settings import ( + get_model_settings_from_yamls, + update_model_parameters +) +from modules.text_generation import ( + encode, + generate_reply, + stop_everything_event +) from modules.utils import get_available_models diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 88359e3e..6afa827d 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -2,12 +2,15 @@ import asyncio import json from threading import Thread -from websockets.server import serve - -from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock +from extensions.api.util import ( + build_parameters, + try_start_cloudflared, + with_api_lock +) from modules import shared from modules.chat import generate_chat_reply from modules.text_generation import generate_reply +from websockets.server import serve PATH = '/api/v1/stream' diff --git a/extensions/api/util.py b/extensions/api/util.py index a9d581eb..2358b7d2 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -10,7 +10,6 @@ from modules import shared from modules.chat import load_character_memoized from modules.presets import load_preset_memoized - # We use a thread local to store the asyncio lock, so that each thread # has its own lock. This isn't strictly necessary, but it makes it # such that if we can support multiple worker threads in the future, diff --git a/extensions/example/script.py b/extensions/example/script.py new file mode 100644 index 00000000..d47b4361 --- /dev/null +++ b/extensions/example/script.py @@ -0,0 +1,129 @@ +""" +An example of extension. It does nothing, but you can add transformations +before the return statements to customize the webui behavior. + +Starting from history_modifier and ending in output_modifier, the +functions are declared in the same order that they are called at +generation time. +""" + +import torch +from modules import chat +from modules.text_generation import ( + decode, + encode, + generate_reply, +) +from transformers import LogitsProcessor + +params = { + "display_name": "Example Extension", + "is_tab": False, +} + +class MyLogits(LogitsProcessor): + """ + Manipulates the probabilities for the next token before it gets sampled. + It gets used in the custom_logits_processor function below. + """ + def __init__(self): + pass + + def __call__(self, input_ids, scores): + # probs = torch.softmax(scores, dim=-1, dtype=torch.float) + # probs[0] /= probs[0].sum() + # scores = torch.log(probs / (1 - probs)) + return scores + +def history_modifier(history): + """ + Modifies the chat history. + Only used in chat mode. + """ + return history + +def state_modifier(state): + """ + Modifies the state variable, which is a dictionary containing the input + values in the UI like sliders and checkboxes. + """ + return state + +def chat_input_modifier(text, visible_text, state): + """ + Modifies the internal and visible input strings in chat mode. + """ + return text, visible_text + +def input_modifier(string, state): + """ + In chat mode, modifies the user input. The modified version goes into + history['internal'], and the original version goes into history['visible']. + + In default/notebook modes, modifies the whole prompt. + """ + return string + +def bot_prefix_modifier(string, state): + """ + Modifies the prefix for the next bot reply in chat mode. + By default, the prefix will be something like "Bot Name:". + """ + return string + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + """ + Modifies the input ids and embeds. + Used by the multimodal extension to put image embeddings in the prompt. + Only used by loaders that use the transformers library for sampling. + """ + return prompt, input_ids, input_embeds + +def logits_processor_modifier(processor_list, input_ids): + """ + Adds logits processors to the list. + Only used by loaders that use the transformers library for sampling. + """ + processor_list.append(MyLogits()) + return processor_list + +def output_modifier(string, state): + """ + Modifies the LLM output before it gets presented. + + In chat mode, the modified version goes into history['internal'], and the original version goes into history['visible']. + """ + return string + +def custom_generate_chat_prompt(user_input, state, **kwargs): + """ + Replaces the function that generates the prompt from the chat history. + Only used in chat mode. + """ + result = chat.generate_chat_prompt(user_input, state, **kwargs) + return result + +def custom_css(): + """ + Returns a CSS string that gets appended to the CSS for the webui. + """ + return '' + +def custom_js(): + """ + Returns a javascript string that gets appended to the javascript for the webui. + """ + return '' + +def setup(): + """ + Gets executed only once, when the extension is imported. + """ + pass + +def ui(): + """ + Gets executed when the UI is drawn. Custom gradio elements and their corresponding + event handlers should be defined here. + """ + pass diff --git a/extensions/multimodal/script.py b/extensions/multimodal/script.py index b3f654e4..8bc26315 100644 --- a/extensions/multimodal/script.py +++ b/extensions/multimodal/script.py @@ -35,6 +35,15 @@ input_hijack = { multimodal_embedder: MultimodalEmbedder = None +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'](text, visible_text) + else: + return text, visible_text + + def add_chat_picture(picture, text, visible_text): # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable) max_hw, min_hw = max(picture.size), min(picture.size) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 63421743..39c9362a 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -9,8 +9,6 @@ from modules import chat, shared from modules.ui import gather_interface_values from modules.utils import gradio -# If 'state' is True, will hijack the next chat generation with -# custom input text given by 'value' in the format [text, visible_text] input_hijack = { 'state': False, 'value': ["", ""] @@ -20,6 +18,15 @@ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'] + else: + return text, visible_text + + def caption_image(raw_image): inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) out = model.generate(**inputs, max_new_tokens=100) @@ -42,7 +49,10 @@ def ui(): # Prepare the input hijack, update the interface values, call the generation function, and clear the picture picture_select.upload( - lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then( + lambda picture, name1, name2: input_hijack.update({ + "state": True, + "value": generate_chat_picture(picture, name1, name2) + }), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then( gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then( lambda: None, None, picture_select, show_progress=False) diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py index 1e07ad2c..cdc55687 100644 --- a/extensions/whisper_stt/script.py +++ b/extensions/whisper_stt/script.py @@ -16,6 +16,15 @@ params = { } +def chat_input_modifier(text, visible_text, state): + global input_hijack + if input_hijack['state']: + input_hijack['state'] = False + return input_hijack['value'] + else: + return text, visible_text + + def do_stt(audio, whipser_model, whipser_language): transcription = "" r = sr.Recognizer() @@ -56,6 +65,7 @@ def ui(): audio.change( auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then( None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}") + whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None) whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None) auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None) diff --git a/modules/chat.py b/modules/chat.py index d2423555..0e0e416c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -175,7 +175,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess # Preparing the input if not any((regenerate, _continue)): - text, visible_text = apply_extensions('input_hijack', text, visible_text) + text, visible_text = apply_extensions('chat_input', text, visible_text, state) if visible_text is None: visible_text = text diff --git a/modules/extensions.py b/modules/extensions.py index faf6cf6d..76b6be8b 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,13 +1,12 @@ import traceback from functools import partial +from inspect import signature import gradio as gr import extensions import modules.shared as shared from modules.logging_colors import logger -from inspect import signature - state = {} available_extensions = [] @@ -66,15 +65,11 @@ def _apply_string_extensions(function_name, text, state): return text -# Input hijack of extensions -def _apply_input_hijack(text, visible_text): +# Extension functions that map string -> string +def _apply_chat_input_extensions(text, visible_text, state): for extension, _ in iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: - extension.input_hijack['state'] = False - if callable(extension.input_hijack['value']): - text, visible_text = extension.input_hijack['value'](text, visible_text) - else: - text, visible_text = extension.input_hijack['value'] + if hasattr(extension, 'chat_input_modifier'): + text, visible_text = extension.chat_input_modifier(text, visible_text, state) return text, visible_text @@ -120,7 +115,11 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e def _apply_logits_processor_extensions(function_name, processor_list, input_ids): for extension, _ in iterator(): if hasattr(extension, function_name): - getattr(extension, function_name)(processor_list, input_ids) + result = getattr(extension, function_name)(processor_list, input_ids) + if type(result) is list: + processor_list = result + + return processor_list # Get prompt length in tokens after applying extension functions which override the default tokenizer output @@ -187,12 +186,12 @@ def create_extensions_tabs(): EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), + "chat_input": _apply_chat_input_extensions, "state": _apply_state_modifier_extensions, "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), 'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'), - "input_hijack": _apply_input_hijack, "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_reply": _apply_custom_generate_reply, "tokenized_length": _apply_custom_tokenized_length,