mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 09:40:20 +01:00
Change the timing for setup() calls
This commit is contained in:
parent
e563b015d8
commit
42ea6a3fc0
@ -11,29 +11,31 @@ setup_called = set()
|
|||||||
|
|
||||||
|
|
||||||
def load_extensions():
|
def load_extensions():
|
||||||
global state
|
global state, setup_called
|
||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
print(f'Loading the extension "{name}"... ', end='')
|
print(f'Loading the extension "{name}"... ', end='')
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
|
extension = eval(f"extensions.{name}.script")
|
||||||
|
if extension not in setup_called and hasattr(extension, "setup"):
|
||||||
|
setup_called.add(extension)
|
||||||
|
extension.setup()
|
||||||
state[name] = [True, i]
|
state[name] = [True, i]
|
||||||
print('Ok.')
|
print('Ok.')
|
||||||
except:
|
except:
|
||||||
print('Fail.')
|
print('Fail.')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# This iterator returns the extensions in the order specified in the command-line
|
# This iterator returns the extensions in the order specified in the command-line
|
||||||
|
|
||||||
|
|
||||||
def iterator():
|
def iterator():
|
||||||
for name in sorted(state, key=lambda x: state[x][1]):
|
for name in sorted(state, key=lambda x: state[x][1]):
|
||||||
if state[name][0]:
|
if state[name][0]:
|
||||||
yield eval(f"extensions.{name}.script"), name
|
yield eval(f"extensions.{name}.script"), name
|
||||||
|
|
||||||
|
|
||||||
# Extension functions that map string -> string
|
# Extension functions that map string -> string
|
||||||
|
|
||||||
|
|
||||||
def apply_extensions(text, typ):
|
def apply_extensions(text, typ):
|
||||||
for extension, _ in iterator():
|
for extension, _ in iterator():
|
||||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||||
@ -57,14 +59,9 @@ def create_extensions_block():
|
|||||||
extension.params[param] = shared.settings[_id]
|
extension.params[param] = shared.settings[_id]
|
||||||
|
|
||||||
should_display_ui = False
|
should_display_ui = False
|
||||||
|
|
||||||
# Running setup function
|
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui"):
|
if hasattr(extension, "ui"):
|
||||||
should_display_ui = True
|
should_display_ui = True
|
||||||
if extension not in setup_called and hasattr(extension, "setup"):
|
|
||||||
setup_called.add(extension)
|
|
||||||
extension.setup()
|
|
||||||
|
|
||||||
# Creating the extension ui elements
|
# Creating the extension ui elements
|
||||||
if should_display_ui:
|
if should_display_ui:
|
||||||
|
Loading…
Reference in New Issue
Block a user