Allow extensions to define custom CSS and JS

This commit is contained in:
oobabooga 2023-05-17 00:03:39 -03:00
parent 824fa8fc0e
commit a84f499718
5 changed files with 39 additions and 43 deletions

View File

@ -110,30 +110,4 @@ button {
.small-button { .small-button {
max-width: 171px; max-width: 171px;
}
/* Align the elements for SD_api_picture extension */
.SDAP #sampler_box {
padding-top: var(--spacing-sm);
padding-bottom: var(--spacing-sm);
}
.SDAP #seed_box,
.SDAP #cfg_box {
padding-top: var(--spacing-md);
}
.SDAP #sampler_box span,
.SDAP #seed_box span,
.SDAP #cfg_box span{
margin-bottom: var(--spacing-sm);
}
.SDAP svg.dropdown-arrow {
flex-shrink: 0 !important;
margin: 0px !important;
}
.SDAP .hires_opts input[type="number"] {
width: 6em !important;
} }

View File

@ -1,16 +1,6 @@
This web UI supports extensions. They are simply files under Extensions are defined by files named `script.py` inside subfolders of `text-generation-webui/extensions`. They are invoked with the `--extensions` flag.
``` For instance, `extensions/silero_tts/script.py` gets invoked with `python server.py --extensions silero_tts`.
extensions/your_extension_name/script.py
```
which can be invoked with the
```
--extension your_extension_name
```
command-line flag.
## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions)
@ -44,6 +34,8 @@ Most of these have been created by the extremely talented contributors that you
| Function | Description | | Function | Description |
|-------------|-------------| |-------------|-------------|
| `def ui()` | Creates custom gradio elements when the UI is launched. | | `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def custom_css()` | Returns custom CSS as a string. |
| `def custom_js()` | Returns custom javascript as a string. |
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. | | `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. |

View File

@ -261,6 +261,11 @@ def SD_api_address_update(address):
return gr.Textbox.update(label=msg) return gr.Textbox.update(label=msg)
def custom_css():
path_to_css = Path(__file__).parent.resolve() / 'style.css'
return open(path_to_css, 'r').read()
def ui(): def ui():
# Gradio elements # Gradio elements

View File

@ -119,6 +119,24 @@ def _apply_custom_generate_reply():
return None return None
def _apply_custom_css():
all_css = ''
for extension, _ in iterator():
if hasattr(extension, 'custom_css'):
all_css += getattr(extension, 'custom_css')()
return all_css
def _apply_custom_js():
all_js = ''
for extension, _ in iterator():
if hasattr(extension, 'custom_js'):
all_js += getattr(extension, 'custom_js')()
return all_js
EXTENSION_MAP = { EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"), "input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"), "output": partial(_apply_string_extensions, "output_modifier"),
@ -128,7 +146,9 @@ EXTENSION_MAP = {
"input_hijack": _apply_input_hijack, "input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
"custom_generate_reply": _apply_custom_generate_reply, "custom_generate_reply": _apply_custom_generate_reply,
"tokenized_length": _apply_custom_tokenized_length "tokenized_length": _apply_custom_tokenized_length,
"css": _apply_custom_css,
"js": _apply_custom_js
} }

View File

@ -45,6 +45,7 @@ from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import chat, shared, training, ui, utils from modules import chat, shared, training, ui, utils
from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, load_soft_prompt, unload_model
@ -519,7 +520,13 @@ def create_interface():
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions() extensions_module.load_extensions()
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']: # css/js strings
css = ui.css if not shared.is_chat() else ui.css + ui.chat_css
js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js
css += apply_extensions('css')
js += apply_extensions('js')
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
# Create chat mode interface # Create chat mode interface
if shared.is_chat(): if shared.is_chat():
@ -826,8 +833,6 @@ def create_interface():
chat.upload_your_profile_picture, shared.gradio['your_picture'], None).then( chat.upload_your_profile_picture, shared.gradio['your_picture'], None).then(
partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, shared.gradio['display']) partial(chat.redraw_html, reset_cache=True), shared.reload_inputs, shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
# notebook/default modes event handlers # notebook/default modes event handlers
else: else:
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
@ -869,8 +874,8 @@ def create_interface():
shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, shared.gradio['prompt_menu'], shared.gradio['textbox'], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
# Extensions block # Extensions block
if shared.args.extensions is not None: if shared.args.extensions is not None: