Allow extensions to define a new tab

This commit is contained in:
oobabooga 2023-05-17 01:25:01 -03:00
parent ad0b71af11
commit ce21804ec7
4 changed files with 51 additions and 22 deletions

View File

@ -59,6 +59,10 @@ ol li p, ul li p {
margin-bottom: 35px; margin-bottom: 35px;
} }
.extension-tab {
border: 0 !important;
}
span.math.inline { span.math.inline {
font-size: 27px; font-size: 27px;
vertical-align: baseline !important; vertical-align: baseline !important;

View File

@ -49,18 +49,35 @@ Additionally, the script may define two special global variables:
#### `params` dictionary #### `params` dictionary
`script.py` may contain a special dictionary called `params`:
```python ```python
params = { params = {
"language string": "ja", "display_name": "Google Translate",
"is_tab": True,
} }
``` ```
This dicionary can be used to make the extension parameters customizable by adding entries to a `settings.json` file like this: In this dictionary, `display_name` is used to define the displayed name of the extension inside the UI, and `is_tab` is used to define whether the extension's `ui()` function should be called in a new `gr.Tab()` that will appear in the header bar. By default, the extension appears at the bottom of the "Text generation" tab.
Additionally, `params` may contain variables that you want to be customizable through a `settings.json` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` below
```python
params = {
"display_name": "Google Translate",
"is_tab": True,
"language string": "jp"
}
```
can be customized by adding a key called `google_translate-language string` to `settings.json`:
```python ```python
"google_translate-language string": "fr", "google_translate-language string": "fr",
``` ```
That is, the syntax is `extension_name-variable_name`.
#### `input_hijack` dictionary #### `input_hijack` dictionary
```python ```python

View File

@ -137,6 +137,30 @@ def _apply_custom_js():
return all_js return all_js
def create_extensions_block():
to_display = []
for extension, name in iterator():
if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
to_display.append((extension, name))
# Creating the extension ui elements
if len(to_display) > 0:
with gr.Column(elem_id="extensions"):
for row in to_display:
extension, name = row
display_name = getattr(extension, 'params', {}).get('display_name', name)
gr.Markdown(f"\n### {display_name}")
extension.ui()
def create_extensions_tabs():
for extension, name in iterator():
if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)):
display_name = getattr(extension, 'params', {}).get('display_name', name)
with gr.Tab(display_name, elem_classes="extension-tab"):
extension.ui()
EXTENSION_MAP = { EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"), "input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"), "output": partial(_apply_string_extensions, "output_modifier"),
@ -157,21 +181,3 @@ def apply_extensions(typ, *args, **kwargs):
raise ValueError(f"Invalid extension type {typ}") raise ValueError(f"Invalid extension type {typ}")
return EXTENSION_MAP[typ](*args, **kwargs) return EXTENSION_MAP[typ](*args, **kwargs)
def create_extensions_block():
global setup_called
should_display_ui = False
for extension, name in iterator():
if hasattr(extension, "ui"):
should_display_ui = True
break
# Creating the extension ui elements
if should_display_ui:
with gr.Column(elem_id="extensions"):
for extension, name in iterator():
if hasattr(extension, "ui"):
gr.Markdown(f"\n### {name}")
extension.ui()

View File

@ -877,9 +877,11 @@ def create_interface():
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
# Extensions tabs
extensions_module.create_extensions_tabs()
# Extensions block # Extensions block
if shared.args.extensions is not None: extensions_module.create_extensions_block()
extensions_module.create_extensions_block()
# Launch the interface # Launch the interface
shared.gradio['interface'].queue() shared.gradio['interface'].queue()