Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-25 09:19:23 +01:00)
Add /v1/internal/model/load endpoint (tentative)
Parent: 43c53a7820
Commit: 2358706453
--- a/extensions/openai/models.py
+++ b/extensions/openai/models.py
@@ -1,4 +1,6 @@
 from modules import shared
+from modules.models import load_model, unload_model
+from modules.models_settings import get_model_metadata, update_model_parameters
 from modules.utils import get_available_models
 
 
@@ -35,3 +37,27 @@ def get_dummy_models() -> list:
         'gpt-3.5-turbo',
         'text-embedding-ada-002',
     ]
+
+
+def _load_model(data):
+    model_name = data["model_name"]
+    args = data["args"]
+    settings = data["settings"]
+
+    unload_model()
+    model_settings = get_model_metadata(model_name)
+    update_model_parameters(model_settings, initial=True)
+
+    # Update shared.args with custom model loading settings
+    if args:
+        for k in args:
+            if k in shared.args:
+                setattr(shared.args, k, args[k])
+
+    shared.model, shared.tokenizer = load_model(model_name)
+
+    # Update shared.settings with custom generation defaults
+    if settings:
+        for k in settings:
+            if k in shared.settings:
+                shared.settings[k] = settings[k]
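Worth noting: the `args` override loop relies on `argparse.Namespace` supporting the `in` operator (via `__contains__`), so keys that do not already exist on `shared.args` are silently ignored rather than added. A minimal self-contained sketch of that pattern (the flag names here are illustrative, not taken from the commit):

from argparse import Namespace

# Stand-in for shared.args; only pre-existing attributes can be overridden.
shared_args = Namespace(cpu=False, load_in_4bit=False)
overrides = {"load_in_4bit": True, "not_a_real_flag": 1}  # hypothetical payload

for k in overrides:
    if k in shared_args:  # Namespace.__contains__ checks its attribute dict
        setattr(shared_args, k, overrides[k])

print(shared_args)  # Namespace(cpu=False, load_in_4bit=True)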
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -1,5 +1,6 @@
 import json
 import os
+import traceback
 from threading import Thread
 
 import extensions.openai.completions as OAIcompletions
@@ -31,6 +32,7 @@ from .typing import (
     DecodeResponse,
     EncodeRequest,
     EncodeResponse,
+    LoadModelRequest,
     ModelInfoResponse,
     TokenCountResponse,
     to_dict
@@ -231,12 +233,22 @@ async def handle_stop_generation(request: Request):
     return JSONResponse(content="OK")
 
 
-@app.get("/v1/internal/model-info", response_model=ModelInfoResponse)
+@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
 async def handle_model_info():
     payload = OAImodels.get_current_model_info()
     return JSONResponse(content=payload)
 
 
+@app.post("/v1/internal/model/load")
+async def handle_load_model(request_data: LoadModelRequest):
+    try:
+        OAImodels._load_model(to_dict(request_data))
+        return JSONResponse(content="OK")
+    except:
+        traceback.print_exc()
+        return HTTPException(status_code=400, detail="Failed to load the model.")
+
+
 def run_server():
     server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
     port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
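A client-side sketch of how the new endpoint might be called, assuming the API extension is listening on the default local port 5000 (adjust for your `--api-port` or `OPENEDAI_PORT`); the model name and override keys below are hypothetical examples:

import requests

url = "http://127.0.0.1:5000/v1/internal/model/load"
payload = {
    "model_name": "my-model",                 # hypothetical model folder name
    "args": {"load_in_4bit": True},           # optional shared.args overrides
    "settings": {"truncation_length": 4096},  # optional shared.settings defaults
}

response = requests.post(url, json=payload)
print(response.status_code, response.json())  # expect "OK" on success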
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -147,6 +147,12 @@ class ModelInfoResponse(BaseModel):
     lora_names: List[str]
 
 
+class LoadModelRequest(BaseModel):
+    model_name: str
+    args: dict | None = None
+    settings: dict | None = None
+
+
 def to_json(obj):
     return json.dumps(obj.__dict__, indent=4)
 
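For illustration, a standalone sketch of the schema: only `model_name` is required, and the `dict | None` union syntax requires Python 3.10+:

from pydantic import BaseModel

class LoadModelRequest(BaseModel):  # mirrors the class added above
    # Pydantic v2 may warn that "model_name" shadows its protected
    # "model_" namespace; the code still runs.
    model_name: str
    args: dict | None = None
    settings: dict | None = None

req = LoadModelRequest(model_name="my-model")  # hypothetical name
print(req.__dict__)
# {'model_name': 'my-model', 'args': None, 'settings': None}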
--- a/modules/models.py
+++ b/modules/models.py
@@ -79,7 +79,7 @@ def load_model(model_name, loader=None):
                 loader = metadata['loader']
                 if loader is None:
                     logger.error('The path to the model does not exist. Exiting.')
-                    return None, None
+                    raise ValueError
 
     shared.args.loader = loader
     output = load_func_map[loader](model_name)
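The effect of this change is that a bad model path now raises instead of returning `(None, None)`, so the new endpoint's `try/except` catches the failure rather than callers having to check the return value. A rough sketch of that control flow, with a stub standing in for the real loader:

def load_model_stub(model_name):
    # Stand-in for modules.models.load_model when the path does not exist.
    raise ValueError

try:
    model, tokenizer = load_model_stub("missing-model")  # hypothetical name
except Exception:
    # The API route prints the traceback and responds with a 400 here.
    print("Failed to load the model.")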
--- a/server.py
+++ b/server.py
@@ -216,8 +216,7 @@ if __name__ == "__main__":
         model_name = shared.model_name
 
         model_settings = get_model_metadata(model_name)
-        shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})  # hijacking the interface defaults
-        update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
+        update_model_parameters(model_settings, initial=True)  # hijack the command-line arguments
 
         # Load the model
         shared.model, shared.tokenizer = load_model(model_name)