text-generation-webui/modules/extensions.py
Andy Salerno 654933c634
New universal API with streaming/blocking endpoints (#990)
Previous title: Add api_streaming extension and update api-example-stream to use it

* Merge with latest main

* Add parameter capturing encoder_repetition_penalty

* Change some defaults, minor fixes

* Add --api, --public-api flags

* Remove unneeded/broken comment from blocking API startup. The comment is already emitted correctly by try_start_cloudflared, which calls the lambda we pass in.

* Update on_start message for blocking_api: it should say 'non-streaming', not 'streaming'

* Update the API examples

* Change a comment

* Update README

* Remove the gradio API

* Remove unused import

* Minor change

* Remove unused import

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
2023-04-23 15:52:43 -03:00

76 lines
2.4 KiB
Python

import importlib
import traceback

import gradio as gr

import extensions
import modules.shared as shared
# Loaded extensions: name -> [enabled flag, index in shared.args.extensions].
# iterator() sorts by the index and skips entries whose flag is falsy.
state: dict = {}
# Extension names that load_extensions() is allowed to import; names requested
# on the command line but missing here are skipped. Populated elsewhere.
available_extensions: list = []
# Extension script modules whose setup() has already been invoked, so it is
# called at most once per module.
setup_called: set = set()
def load_extensions():
    """Import and set up every extension requested on the command line.

    Walks ``shared.args.extensions`` in order. Each name that also appears in
    ``available_extensions`` is imported as ``extensions.<name>.script``. If
    that module defines ``setup()`` and it has not been run yet, ``setup()``
    is called once. On success ``state[name]`` is set to ``[True, i]``, where
    ``i`` is the name's index in the command-line list; ``iterator()`` orders
    extensions by it.

    Names not present in ``available_extensions`` are skipped without a
    message. If loading one extension raises, ``Fail.`` and the traceback are
    printed and loading continues with the next name.
    """
    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        # The 'api' extension is loaded quietly: no progress messages for it.
        verbose = name != 'api'
        if verbose:
            print(f'Loading the extension "{name}"... ', end='')

        try:
            # import_module replaces the former exec("import ...") and returns
            # the same module object; the import machinery still binds
            # extensions.<name>.script on the package as a side effect.
            extension = importlib.import_module(f"extensions.{name}.script")
            # The module is recorded before setup() runs, so a setup() that
            # raises is not attempted again on a later call.
            if extension not in setup_called and hasattr(extension, "setup"):
                setup_called.add(extension)
                extension.setup()
            state[name] = [True, i]
            if verbose:
                print('Ok.')
        except Exception:
            # `Exception` rather than a bare except: a failing extension is
            # reported and skipped, but Ctrl-C / SystemExit still propagate.
            if verbose:
                print('Fail.')
            traceback.print_exc()
def iterator():
    """Yield ``(script_module, name)`` for each enabled extension.

    Extensions are visited in ascending order of their recorded position
    (``state[name][1]``), i.e. the order given on the command line. An entry
    whose enabled flag (``state[name][0]``) is falsy is skipped. The flag is
    read lazily, as each entry is reached.
    """
    def _position(ext_name):
        return state[ext_name][1]

    for ext_name in sorted(state, key=_position):
        if not state[ext_name][0]:
            continue
        module = getattr(extensions, ext_name).script
        yield module, ext_name
def apply_extensions(text, typ):
    """Run *text* through the string -> string hook of kind *typ* on every enabled extension.

    ``typ`` selects the hook: ``"input"`` -> ``input_modifier``,
    ``"output"`` -> ``output_modifier``, and ``"bot_prefix"`` ->
    ``bot_prefix_modifier``. Extensions are visited in ``iterator()`` order,
    and each hook receives the previous hook's result. Extensions lacking the
    hook, or an unrecognised ``typ``, leave the text unchanged. Returns the
    final text.
    """
    hook_name = None
    for kind, attr in (("input", "input_modifier"),
                       ("output", "output_modifier"),
                       ("bot_prefix", "bot_prefix_modifier")):
        if typ == kind:
            hook_name = attr
            break

    result = text
    for extension, _name in iterator():
        if hook_name is not None and hasattr(extension, hook_name):
            result = getattr(extension, hook_name)(result)
    return result
def create_extensions_block():
    """Build the Gradio UI section for the enabled extensions.

    First, every enabled extension that has a ``params`` mapping gets each
    entry overwritten by the matching ``"<name>-<param>"`` key in
    ``shared.settings``, when such a key exists.

    Then, if at least one enabled extension defines ``ui()``, a
    ``gr.Column(elem_id="extensions")`` is opened. Inside it, each such
    extension contributes a ``### <name>`` Markdown heading followed by
    whatever its ``ui()`` builds. Extensions without ``ui()`` add nothing.
    """
    # Updating the default values from the saved settings
    for extension, name in iterator():
        if hasattr(extension, 'params'):
            # Only existing keys are reassigned, so the mapping's size does
            # not change while it is being iterated.
            for param in extension.params:
                setting_key = f"{name}-{param}"
                if setting_key in shared.settings:
                    extension.params[param] = shared.settings[setting_key]

    # Skip the container entirely when no extension contributes any UI.
    if not any(hasattr(extension, "ui") for extension, _ in iterator()):
        return

    # Creating the extension ui elements
    with gr.Column(elem_id="extensions"):
        for extension, name in iterator():
            # The heading goes with the UI; extensions without ui() would
            # otherwise leave an empty section header behind.
            if hasattr(extension, "ui"):
                gr.Markdown(f"\n### {name}")
                extension.ui()