From a84f499718dac1f82fd7165a77156d44760dcef3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 17 May 2023 00:03:39 -0300 Subject: [PATCH] Allow extensions to define custom CSS and JS --- css/main.css | 26 -------------------------- docs/Extensions.md | 16 ++++------------ extensions/sd_api_pictures/script.py | 5 +++++ modules/extensions.py | 22 +++++++++++++++++++++- server.py | 13 +++++++++---- 5 files changed, 39 insertions(+), 43 deletions(-) diff --git a/css/main.css b/css/main.css index cdde2705..83b8850b 100644 --- a/css/main.css +++ b/css/main.css @@ -110,30 +110,4 @@ button { .small-button { max-width: 171px; -} - -/* Align the elements for SD_api_picture extension */ -.SDAP #sampler_box { - padding-top: var(--spacing-sm); - padding-bottom: var(--spacing-sm); -} - -.SDAP #seed_box, -.SDAP #cfg_box { - padding-top: var(--spacing-md); -} - -.SDAP #sampler_box span, -.SDAP #seed_box span, -.SDAP #cfg_box span{ - margin-bottom: var(--spacing-sm); -} - -.SDAP svg.dropdown-arrow { - flex-shrink: 0 !important; - margin: 0px !important; -} - -.SDAP .hires_opts input[type="number"] { - width: 6em !important; } \ No newline at end of file diff --git a/docs/Extensions.md b/docs/Extensions.md index 08d698e9..0e52c8d1 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -1,16 +1,6 @@ -This web UI supports extensions. They are simply files under +Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are invoked with the `--extensions` flag. -``` -extensions/your_extension_name/script.py -``` - -which can be invoked with the - -``` ---extension your_extension_name -``` - -command-line flag. +For instance, `extensions/silero_tts/script.py` gets invoked with `python server.py --extensions silero_tts`. ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) @@ -44,6 +34,8 @@ Most of these have been created by the extremely talented contributors that you | Function | Description | |-------------|-------------| | `def ui()` | Creates custom gradio elements when the UI is launched. | +| `def custom_css()` | Returns custom CSS as a string. | +| `def custom_js()` | Returns custom javascript as a string. | | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. | diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index 2c054242..949531c9 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -261,6 +261,11 @@ def SD_api_address_update(address): return gr.Textbox.update(label=msg) +def custom_css(): + path_to_css = Path(__file__).parent.resolve() / 'style.css' + return open(path_to_css, 'r').read() + + def ui(): # Gradio elements diff --git a/modules/extensions.py b/modules/extensions.py index 62a21dfb..fe8cb7be 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -119,6 +119,24 @@ def _apply_custom_generate_reply(): return None +def _apply_custom_css(): + all_css = '' + for extension, _ in iterator(): + if hasattr(extension, 'custom_css'): + all_css += getattr(extension, 'custom_css')() + + return all_css + + +def _apply_custom_js(): + all_js = '' + for extension, _ in iterator(): + if hasattr(extension, 'custom_js'): + all_js += getattr(extension, 'custom_js')() + + return all_js + + EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), @@ -128,7 +146,9 @@ EXTENSION_MAP = { "input_hijack": _apply_input_hijack, "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_reply": _apply_custom_generate_reply, - "tokenized_length": _apply_custom_tokenized_length + "tokenized_length": _apply_custom_tokenized_length, + "css": _apply_custom_css, + "js": _apply_custom_js } diff --git a/server.py b/server.py index e4993203..334841e6 100644 --- a/server.py +++ b/server.py @@ -45,6 +45,7 @@ from PIL import Image import modules.extensions as extensions_module from modules import chat, shared, training, ui, utils +from modules.extensions import apply_extensions from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt, unload_model @@ -519,7 +520,13 @@ def create_interface(): if shared.args.extensions is not None and len(shared.args.extensions) > 0: extensions_module.load_extensions() - with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: + # css/js strings + css = ui.css if not shared.is_chat() else ui.css + ui.chat_css + js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js + css += apply_extensions('css') + js += apply_extensions('js') + + with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: # Create chat mode interface if shared.is_chat(): @@ -826,8 +833,6 @@ def create_interface(): chat.upload_your_profile_picture, shared.gradio['your_picture'], None).then( partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, shared.gradio['display']) - shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") - # notebook/default modes event handlers else: shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] @@ -869,8 +874,8 @@ def create_interface(): shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) - shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}") shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) # Extensions block if shared.args.extensions is not None: