mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-12-25 13:58:56 +01:00)
Refactor the /v1/models endpoint

commit 43c53a7820 (parent 1b69694fe9)
@@ -1,9 +1,4 @@
-from extensions.openai.embeddings import get_embeddings_model_name
-from extensions.openai.errors import OpenAIError
 from modules import shared
-from modules.models import load_model as _load_model
-from modules.models import unload_model
-from modules.models_settings import get_model_metadata, update_model_parameters
 from modules.utils import get_available_models
 
 
@@ -14,72 +9,29 @@ def get_current_model_info():
     }
 
 
-def get_current_model_list() -> list:
-    return [shared.model_name]  # The real chat/completions model, maybe "None"
+def list_models():
+    result = {
+        "object": "list",
+        "data": []
+    }
+
+    for model in get_dummy_models() + get_available_models()[1:]:
+        result["data"].append(model_info_dict(model))
+
+    return result
 
 
-def get_pseudo_model_list() -> list:
+def model_info_dict(model_name: str) -> dict:
+    return {
+        "id": model_name,
+        "object": "model",
+        "created": 0,
+        "owned_by": "user"
+    }
+
+
+def get_dummy_models() -> list:
     return [  # these are expected by so much, so include some here as a dummy
         'gpt-3.5-turbo',
         'text-embedding-ada-002',
     ]
-
-
-def load_model(model_name: str) -> dict:
-    resp = {
-        "id": model_name,
-        "object": "engine",
-        "owner": "self",
-        "ready": True,
-    }
-    if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list():  # Real model only
-        # No args. Maybe it works anyways!
-        # TODO: hack some heuristics into args for better results
-
-        shared.model_name = model_name
-        unload_model()
-
-        model_settings = get_model_metadata(shared.model_name)
-        shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
-        update_model_parameters(model_settings, initial=True)
-
-        if shared.settings['mode'] != 'instruct':
-            shared.settings['instruction_template'] = None
-
-        shared.model, shared.tokenizer = _load_model(shared.model_name)
-
-        if not shared.model:  # load failed.
-            shared.model_name = "None"
-            raise OpenAIError(f"Model load failed for: {shared.model_name}")
-
-    return resp
-
-
-def list_models(is_legacy: bool = False) -> dict:
-    # TODO: Lora's?
-    all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models()
-
-    models = {}
-
-    if is_legacy:
-        models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
-        if not shared.model:
-            models[0]['ready'] = False
-    else:
-        models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]
-
-    resp = {
-        "object": "list",
-        "data": models,
-    }
-
-    return resp
-
-
-def model_info(model_name: str) -> dict:
-    return {
-        "id": model_name,
-        "object": "model",
-        "owned_by": "user",
-        "permission": []
-    }
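Taken together, the two hunks above (the OpenAI-extension model helpers, extensions/openai/models.py judging by the imports) reduce the module to three read-only functions. A minimal standalone sketch of the payload the new list_models() assembles; the two local model names are hypothetical stand-ins for whatever get_available_models() finds on disk:

def model_info_dict(model_name: str) -> dict:  # copied from the diff above
    return {"id": model_name, "object": "model", "created": 0, "owned_by": "user"}

def get_dummy_models() -> list:  # copied from the diff above
    return ['gpt-3.5-turbo', 'text-embedding-ada-002']

available = ['None', 'llama-2-7b', 'mistral-7b']  # stands in for get_available_models()

result = {"object": "list", "data": []}
for model in get_dummy_models() + available[1:]:  # [1:] skips the 'None' placeholder
    result["data"].append(model_info_dict(model))

print([entry["id"] for entry in result["data"]])
# ['gpt-3.5-turbo', 'text-embedding-ada-002', 'llama-2-7b', 'mistral-7b']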
@@ -112,22 +112,18 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
 
 
 @app.get("/v1/models")
-@app.get("/v1/engines")
 @app.get("/v1/models/{model}")
 async def handle_models(request: Request):
     path = request.url.path
-    is_legacy = 'engines' in path
-    is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
-    if is_legacy and not is_list:
-        model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
-        resp = OAImodels.load_model(model_name)
-    elif is_list:
-        resp = OAImodels.list_models(is_legacy)
+    is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
+
+    if is_list:
+        response = OAImodels.list_models()
     else:
         model_name = path[len('/v1/models/'):]
-        resp = OAImodels.model_info(model_name)
+        response = OAImodels.model_info_dict(model_name)
 
-    return JSONResponse(content=resp)
+    return JSONResponse(response)
 
 
 @app.get('/v1/billing/usage')
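With the /v1/engines routes and the load-model-on-GET side effect gone, the handler is purely informational. A client-side check of the new behavior, assuming the API is reachable on localhost port 5000 (adjust the base URL to your launch flags):

import requests  # any HTTP client works; requests is used here for brevity

base_url = "http://localhost:5000"  # assumption: default host/port, adjust as needed

# GET /v1/models -> the full list (dummy models first, then on-disk models)
models = requests.get(f"{base_url}/v1/models").json()
print([m["id"] for m in models["data"]])

# GET /v1/models/{model} -> a static model object from model_info_dict();
# unlike the old /v1/engines/{model} route, this no longer loads the model.
info = requests.get(f"{base_url}/v1/models/gpt-3.5-turbo").json()
print(info)  # {'id': 'gpt-3.5-turbo', 'object': 'model', 'created': 0, 'owned_by': 'user'}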
@@ -71,12 +71,12 @@ def natural_keys(text):
 
 
 def get_available_models():
-    model_list = ['None']
+    model_list = []
     for item in list(Path(f'{shared.args.model_dir}/').glob('*')):
        if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml', '.py')) and 'llama-tokenizer' not in item.name:
             model_list.append(re.sub('.pth$', '', item.name))
 
-    return sorted(model_list, key=natural_keys)
+    return ['None'] + sorted(model_list, key=natural_keys)
 
 
 def get_available_presets():
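This last hunk is what makes the get_available_models()[1:] slice in the new list_models() safe: 'None' is now prepended after sorting, so it always sits at index 0. If natural_keys compares case-insensitively (an assumption here, consistent with a lowercasing sort key), sorting 'None' together with real model names could push it out of first place. A small sketch with hypothetical model names, using str.lower as a stand-in for natural_keys:

models = ['Airoboros-13b', 'mistral-7b']  # hypothetical on-disk model names

old_style = sorted(models + ['None'], key=str.lower)  # ['Airoboros-13b', 'mistral-7b', 'None'] - placeholder drifts
new_style = ['None'] + sorted(models, key=str.lower)  # ['None', 'Airoboros-13b', 'mistral-7b'] - always index 0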