Merge pull request #433 from mayaeary/fix/api-reload

Fix api extension duplicating
This commit is contained in:
oobabooga 2023-03-24 16:56:10 -03:00 committed by GitHub
commit c14e598f14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 4 deletions

View File

@ -1,8 +1,9 @@
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
from modules import shared from modules import shared
from modules.text_generation import generate_reply, encode from modules.text_generation import encode, generate_reply
import json
params = { params = {
'port': 5000, 'port': 5000,
@ -87,5 +88,5 @@ def run_server():
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever() server.serve_forever()
def ui(): def setup():
Thread(target=run_server, daemon=True).start() Thread(target=run_server, daemon=True).start()

View File

@ -7,6 +7,7 @@ import modules.shared as shared
state = {} state = {}
available_extensions = [] available_extensions = []
setup_called = False
def load_extensions(): def load_extensions():
global state global state
@ -39,6 +40,8 @@ def apply_extensions(text, typ):
return text return text
def create_extensions_block(): def create_extensions_block():
global setup_called
# Updating the default values # Updating the default values
for extension, name in iterator(): for extension, name in iterator():
if hasattr(extension, 'params'): if hasattr(extension, 'params'):
@ -47,8 +50,19 @@ def create_extensions_block():
if _id in shared.settings: if _id in shared.settings:
extension.params[param] = shared.settings[_id] extension.params[param] = shared.settings[_id]
should_display_ui = False
# Running setup function
if not setup_called:
for extension, name in iterator():
if hasattr(extension, "setup"):
extension.setup()
if hasattr(extension, "ui"):
should_display_ui = True
setup_called = True
# Creating the extension ui elements # Creating the extension ui elements
if len(state) > 0: if should_display_ui:
with gr.Box(elem_id="extensions"): with gr.Box(elem_id="extensions"):
gr.Markdown("Extensions") gr.Markdown("Extensions")
for extension, name in iterator(): for extension, name in iterator():