mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Add /v1/internal/model/load endpoint (tentative)
This commit is contained in:
parent
43c53a7820
commit
2358706453
@ -1,4 +1,6 @@
|
||||
from modules import shared
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.utils import get_available_models
|
||||
|
||||
|
||||
@ -35,3 +37,27 @@ def get_dummy_models() -> list:
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
|
||||
|
||||
def _load_model(data):
    """Unload the current model and load the one named in ``data``.

    ``data`` is indexed with the keys ``"model_name"``, ``"args"`` and
    ``"settings"``, so all three must be present. ``args`` and ``settings``
    are skipped when falsy; keys they contain that are not already present
    in ``shared.args`` / ``shared.settings`` are ignored.
    """
    model_name = data["model_name"]
    loader_overrides = data["args"]
    generation_overrides = data["settings"]

    unload_model()

    # Feed the new model's metadata to update_model_parameters before any
    # caller-supplied overrides are applied.
    update_model_parameters(get_model_metadata(model_name), initial=True)

    # Update shared.args with custom model loading settings
    if loader_overrides:
        for key, value in loader_overrides.items():
            if key in shared.args:
                setattr(shared.args, key, value)

    shared.model, shared.tokenizer = load_model(model_name)

    # Update shared.settings with custom generation defaults
    if generation_overrides:
        for key, value in generation_overrides.items():
            if key in shared.settings:
                shared.settings[key] = value
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from threading import Thread
|
||||
|
||||
import extensions.openai.completions as OAIcompletions
|
||||
@ -31,6 +32,7 @@ from .typing import (
|
||||
DecodeResponse,
|
||||
EncodeRequest,
|
||||
EncodeResponse,
|
||||
LoadModelRequest,
|
||||
ModelInfoResponse,
|
||||
TokenCountResponse,
|
||||
to_dict
|
||||
@ -231,12 +233,22 @@ async def handle_stop_generation(request: Request):
|
||||
return JSONResponse(content="OK")
|
||||
|
||||
|
||||
# Both paths below are registered for this handler.
@app.get("/v1/internal/model-info", response_model=ModelInfoResponse)
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
async def handle_model_info():
    """Return the result of ``OAImodels.get_current_model_info()`` as JSON."""
    return JSONResponse(content=OAImodels.get_current_model_info())
|
||||
|
||||
|
||||
@app.post("/v1/internal/model/load")
async def handle_load_model(request_data: LoadModelRequest):
    """Load the model described by the request body.

    Responds with the JSON string ``"OK"`` on success.

    Raises:
        HTTPException: status 400 when ``OAImodels._load_model`` raises; the
            original traceback is printed to stderr first.
    """
    try:
        OAImodels._load_model(to_dict(request_data))
    except Exception as exc:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt,
        # SystemExit and task cancellation are not swallowed.
        traceback.print_exc()
        # The exception must be raised, not returned: a returned
        # HTTPException is treated as an ordinary response value and the
        # client never sees the 400 status.
        raise HTTPException(status_code=400, detail="Failed to load the model.") from exc

    return JSONResponse(content="OK")
|
||||
|
||||
|
||||
def run_server():
|
||||
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||
|
@ -147,6 +147,12 @@ class ModelInfoResponse(BaseModel):
|
||||
lora_names: List[str]
|
||||
|
||||
|
||||
class LoadModelRequest(BaseModel):
    # Request body type taken by the handle_load_model endpoint.

    # Required: name of the model to load.
    model_name: str

    # Optional overrides; both default to None when omitted.
    args: dict | None = None
    settings: dict | None = None
|
||||
|
||||
|
||||
def to_json(obj):
    """Serialize the instance attributes of ``obj`` to a JSON string.

    The object's ``__dict__`` is dumped with a 4-space indent.
    """
    attributes = obj.__dict__
    return json.dumps(attributes, indent=4)
|
||||
|
||||
|
@ -79,7 +79,7 @@ def load_model(model_name, loader=None):
|
||||
loader = metadata['loader']
|
||||
if loader is None:
|
||||
logger.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
raise ValueError
|
||||
|
||||
shared.args.loader = loader
|
||||
output = load_func_map[loader](model_name)
|
||||
|
@ -216,8 +216,7 @@ if __name__ == "__main__":
|
||||
model_name = shared.model_name
|
||||
|
||||
model_settings = get_model_metadata(model_name)
|
||||
shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults
|
||||
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
|
||||
update_model_parameters(model_settings, initial=True) # hijack the command-line arguments
|
||||
|
||||
# Load the model
|
||||
shared.model, shared.tokenizer = load_model(model_name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user