Add /v1/internal/model/load endpoint (tentative)

oobabooga 2023-11-07 20:58:06 -08:00
parent 43c53a7820
commit 2358706453
5 changed files with 47 additions and 4 deletions

extensions/openai/models.py

@@ -1,4 +1,6 @@
from modules import shared
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_models
@@ -35,3 +37,27 @@ def get_dummy_models() -> list:
        'gpt-3.5-turbo',
        'text-embedding-ada-002',
    ]


def _load_model(data):
    model_name = data["model_name"]
    args = data["args"]
    settings = data["settings"]

    unload_model()
    model_settings = get_model_metadata(model_name)
    update_model_parameters(model_settings, initial=True)

    # Update shared.args with custom model loading settings
    if args:
        for k in args:
            if k in shared.args:
                setattr(shared.args, k, args[k])

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults
    if settings:
        for k in settings:
            if k in shared.settings:
                shared.settings[k] = settings[k]
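For reference, a minimal sketch of the payload shape _load_model consumes once the endpoint body is parsed; the model name and option keys below are illustrative placeholders, not values from this commit:

# Hypothetical payload: "model_name" should match a local model directory,
# "args" keys must be existing shared.args attributes, and "settings" keys
# existing shared.settings keys (unknown keys are silently skipped).
payload = {
    "model_name": "my-local-model",           # placeholder
    "args": {"load_in_4bit": True},           # assumed shared.args attribute
    "settings": {"truncation_length": 4096},  # assumed shared.settings key
}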

extensions/openai/script.py

@@ -1,5 +1,6 @@
import json
import os
import traceback
from threading import Thread
import extensions.openai.completions as OAIcompletions
@@ -31,6 +32,7 @@ from .typing import (
    DecodeResponse,
    EncodeRequest,
    EncodeResponse,
    LoadModelRequest,
    ModelInfoResponse,
    TokenCountResponse,
    to_dict
@@ -231,12 +233,22 @@ async def handle_stop_generation(request: Request):
    return JSONResponse(content="OK")


@app.get("/v1/internal/model-info", response_model=ModelInfoResponse)
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
async def handle_model_info():
    payload = OAImodels.get_current_model_info()
    return JSONResponse(content=payload)


@app.post("/v1/internal/model/load")
async def handle_load_model(request_data: LoadModelRequest):
    try:
        OAImodels._load_model(to_dict(request_data))
        return JSONResponse(content="OK")
    except:
        traceback.print_exc()
        return HTTPException(status_code=400, detail="Failed to load the model.")
def run_server():
    server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
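A hedged client-side sketch of exercising the new endpoint; it assumes the API is reachable at 127.0.0.1:5000 (adjust for your OPENEDAI_PORT or --api-port setting) and uses a placeholder model name:

import requests

# The JSON body mirrors LoadModelRequest: model_name plus optional args/settings.
response = requests.post(
    "http://127.0.0.1:5000/v1/internal/model/load",
    json={"model_name": "my-local-model", "args": None, "settings": None},
)
print(response.status_code, response.json())  # expect 200 and "OK" on success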

extensions/openai/typing.py

@@ -147,6 +147,12 @@ class ModelInfoResponse(BaseModel):
    lora_names: List[str]


class LoadModelRequest(BaseModel):
    model_name: str
    args: dict | None = None
    settings: dict | None = None


def to_json(obj):
    return json.dumps(obj.__dict__, indent=4)
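As an illustration, the new schema can be exercised on its own; pydantic is assumed here since the file's models subclass BaseModel, and the dict | None annotations imply Python 3.10+:

from pydantic import BaseModel

class LoadModelRequest(BaseModel):
    model_name: str
    args: dict | None = None      # optional overrides for shared.args
    settings: dict | None = None  # optional overrides for shared.settings

req = LoadModelRequest(model_name="my-local-model")  # placeholder name
print(req)  # prints the validated fields: model_name, args, settings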

modules/models.py

@@ -79,7 +79,7 @@ def load_model(model_name, loader=None):
            loader = metadata['loader']
            if loader is None:
                logger.error('The path to the model does not exist. Exiting.')
                return None, None
                raise ValueError

    shared.args.loader = loader
    output = load_func_map[loader](model_name)
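Note the behavior change here: a missing loader now raises instead of returning (None, None), which is what the endpoint's except clause above converts into an HTTP 400. A minimal sketch of a caller handling it, assuming load_model and model_name are in scope:

try:
    shared.model, shared.tokenizer = load_model(model_name)
except ValueError:
    # Previously this case surfaced as (None, None); now it must be caught.
    print("Failed to load the model.")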

server.py

@@ -216,8 +216,7 @@ if __name__ == "__main__":
        model_name = shared.model_name
        model_settings = get_model_metadata(model_name)
        shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})  # hijacking the interface defaults
        update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
        update_model_parameters(model_settings, initial=True)  # hijack the command-line arguments

        # Load the model
        shared.model, shared.tokenizer = load_model(model_name)