From 2358706453e7969a5aa80c865527db1c0c7f8a70 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 7 Nov 2023 20:58:06 -0800
Subject: [PATCH] Add /v1/internal/model/load endpoint (tentative)

---
 extensions/openai/models.py | 26 ++++++++++++++++++++++++++
 extensions/openai/script.py | 14 +++++++++++++-
 extensions/openai/typing.py |  6 ++++++
 modules/models.py           |  2 +-
 server.py                   |  3 +--
 5 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/extensions/openai/models.py b/extensions/openai/models.py
index 4e31a700..053c7ca1 100644
--- a/extensions/openai/models.py
+++ b/extensions/openai/models.py
@@ -1,4 +1,6 @@
 from modules import shared
+from modules.models import load_model, unload_model
+from modules.models_settings import get_model_metadata, update_model_parameters
 from modules.utils import get_available_models
 
 
@@ -35,3 +37,27 @@ def get_dummy_models() -> list:
         'gpt-3.5-turbo',
         'text-embedding-ada-002',
     ]
+
+
+def _load_model(data):
+    model_name = data["model_name"]
+    args = data["args"]
+    settings = data["settings"]
+
+    unload_model()
+    model_settings = get_model_metadata(model_name)
+    update_model_parameters(model_settings, initial=True)
+
+    # Update shared.args with custom model loading settings
+    if args:
+        for k in args:
+            if k in shared.args:
+                setattr(shared.args, k, args[k])
+
+    shared.model, shared.tokenizer = load_model(model_name)
+
+    # Update shared.settings with custom generation defaults
+    if settings:
+        for k in settings:
+            if k in shared.settings:
+                shared.settings[k] = settings[k]
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index c9b3fb03..4f8bb0d2 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -1,5 +1,6 @@
 import json
 import os
+import traceback
 from threading import Thread
 
 import extensions.openai.completions as OAIcompletions
@@ -31,6 +32,7 @@ from .typing import (
     DecodeResponse,
     EncodeRequest,
     EncodeResponse,
+    LoadModelRequest,
     ModelInfoResponse,
     TokenCountResponse,
     to_dict
@@ -231,12 +233,22 @@ async def handle_stop_generation(request: Request):
     return JSONResponse(content="OK")
 
 
-@app.get("/v1/internal/model-info", response_model=ModelInfoResponse)
+@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
 async def handle_model_info():
     payload = OAImodels.get_current_model_info()
     return JSONResponse(content=payload)
 
 
+@app.post("/v1/internal/model/load")
+async def handle_load_model(request_data: LoadModelRequest):
+    try:
+        OAImodels._load_model(to_dict(request_data))
+        return JSONResponse(content="OK")
+    except Exception:
+        traceback.print_exc()
+        raise HTTPException(status_code=400, detail="Failed to load the model.")
+
+
 def run_server():
     server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
     port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index da19e2be..11fd5f65 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -147,6 +147,12 @@ class ModelInfoResponse(BaseModel):
     lora_names: List[str]
 
 
+class LoadModelRequest(BaseModel):
+    model_name: str
+    args: dict | None = None
+    settings: dict | None = None
+
+
 def to_json(obj):
     return json.dumps(obj.__dict__, indent=4)
 
diff --git a/modules/models.py b/modules/models.py
index d0392485..cc9b405c 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -79,7 +79,7 @@ def load_model(model_name, loader=None):
             loader = metadata['loader']
             if loader is None:
                 logger.error('The path to the model does not exist. Exiting.')
-                return None, None
+                raise ValueError
 
     shared.args.loader = loader
     output = load_func_map[loader](model_name)
diff --git a/server.py b/server.py
index 4218967f..1a87ef45 100644
--- a/server.py
+++ b/server.py
@@ -216,8 +216,7 @@ if __name__ == "__main__":
     if shared.model_name != 'None':
         model_name = shared.model_name
         model_settings = get_model_metadata(model_name)
-        shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})  # hijacking the interface defaults
-        update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
+        update_model_parameters(model_settings, initial=True)  # hijack the command-line arguments
 
         # Load the model
         shared.model, shared.tokenizer = load_model(model_name)
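
Usage sketch (not part of the patch): one way to exercise the new endpoint
from Python, assuming the API server is listening on the default
127.0.0.1:5000 and that "MyModel-7B" is a hypothetical folder under models/.
Per _load_model above, keys in "args" and "settings" that do not already
exist in shared.args / shared.settings are silently ignored.

    import requests

    # POST the load request to the endpoint added by this patch.
    # "args" overrides model loading flags (shared.args); "settings"
    # overrides generation defaults (shared.settings). Both are optional.
    payload = {
        "model_name": "MyModel-7B",  # hypothetical model folder name
        "args": {"load_in_4bit": True},
        "settings": {"truncation_length": 4096},
    }

    response = requests.post(
        "http://127.0.0.1:5000/v1/internal/model/load",
        json=payload,
    )
    print(response.status_code, response.json())  # expect: 200 OK

On failure the endpoint prints the traceback server-side and responds with
status 400 and detail "Failed to load the model."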