diff --git a/extensions/openai/models.py b/extensions/openai/models.py index b213c1f8..4e31a700 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -1,9 +1,4 @@ -from extensions.openai.embeddings import get_embeddings_model_name -from extensions.openai.errors import OpenAIError from modules import shared -from modules.models import load_model as _load_model -from modules.models import unload_model -from modules.models_settings import get_model_metadata, update_model_parameters from modules.utils import get_available_models @@ -14,72 +9,29 @@ def get_current_model_info(): } -def get_current_model_list() -> list: - return [shared.model_name] # The real chat/completions model, maybe "None" +def list_models(): + result = { + "object": "list", + "data": [] + } + + for model in get_dummy_models() + get_available_models()[1:]: + result["data"].append(model_info_dict(model)) + + return result -def get_pseudo_model_list() -> list: +def model_info_dict(model_name: str) -> dict: + return { + "id": model_name, + "object": "model", + "created": 0, + "owned_by": "user" + } + + +def get_dummy_models() -> list: return [ # these are expected by so much, so include some here as a dummy 'gpt-3.5-turbo', 'text-embedding-ada-002', ] - - -def load_model(model_name: str) -> dict: - resp = { - "id": model_name, - "object": "engine", - "owner": "self", - "ready": True, - } - if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list(): # Real model only - # No args. Maybe it works anyways! - # TODO: hack some heuristics into args for better results - - shared.model_name = model_name - unload_model() - - model_settings = get_model_metadata(shared.model_name) - shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) - update_model_parameters(model_settings, initial=True) - - if shared.settings['mode'] != 'instruct': - shared.settings['instruction_template'] = None - - shared.model, shared.tokenizer = _load_model(shared.model_name) - - if not shared.model: # load failed. - shared.model_name = "None" - raise OpenAIError(f"Model load failed for: {shared.model_name}") - - return resp - - -def list_models(is_legacy: bool = False) -> dict: - # TODO: Lora's? - all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models() - - models = {} - - if is_legacy: - models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list] - if not shared.model: - models[0]['ready'] = False - else: - models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list] - - resp = { - "object": "list", - "data": models, - } - - return resp - - -def model_info(model_name: str) -> dict: - return { - "id": model_name, - "object": "model", - "owned_by": "user", - "permission": [] - } diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 361b97a3..c9b3fb03 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -112,22 +112,18 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion @app.get("/v1/models") -@app.get("/v1/engines") +@app.get("/v1/models/{model}") async def handle_models(request: Request): path = request.url.path - is_legacy = 'engines' in path - is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models'] + is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models' - if is_legacy and not is_list: - model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):] - resp = OAImodels.load_model(model_name) - elif is_list: - resp = OAImodels.list_models(is_legacy) + if is_list: + response = OAImodels.list_models() else: model_name = path[len('/v1/models/'):] - resp = OAImodels.model_info(model_name) + response = OAImodels.model_info_dict(model_name) - return JSONResponse(content=resp) + return JSONResponse(response) @app.get('/v1/billing/usage') diff --git a/modules/utils.py b/modules/utils.py index 369d0b70..69953da7 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -71,12 +71,12 @@ def natural_keys(text): def get_available_models(): - model_list = ['None'] + model_list = [] for item in list(Path(f'{shared.args.model_dir}/').glob('*')): if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml', '.py')) and 'llama-tokenizer' not in item.name: model_list.append(re.sub('.pth$', '', item.name)) - return sorted(model_list, key=natural_keys) + return ['None'] + sorted(model_list, key=natural_keys) def get_available_presets():