diff --git a/extensions/openai/models.py b/extensions/openai/models.py
index a7e67df6..1325a03b 100644
--- a/extensions/openai/models.py
+++ b/extensions/openai/models.py
@@ -15,7 +15,30 @@ def get_current_model_info():
 
 
 def list_models():
-    return {'model_names': get_available_models()[1:]}
+    mode = shared.args.model_selection_mode
+
+    result = {
+        "object": "list",
+        "data": []
+    }
+
+    # Include the dummy models if bit 0 is set
+    if mode & 1:
+        dummy_models = ['gpt-3.5-turbo', 'text-embedding-ada-002']
+        for model in dummy_models:
+            result["data"].append(model_info_dict(model))
+
+    # Include the local models if bit 1 is set
+    if mode & 2:
+        if mode & 4:
+            # Return only the currently loaded model
+            result["data"].append(model_info_dict(shared.model_name))
+        else:
+            # Return all available models
+            for model in get_available_models():
+                result["data"].append(model_info_dict(model))
+
+    return result
 
 
 def list_dummy_models():
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 03d99e8d..e85738c3 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -147,7 +147,7 @@ async def handle_models(request: Request):
     is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
 
     if is_list:
-        response = OAImodels.list_dummy_models()
+        response = OAImodels.list_models()
     else:
         model_name = path[len('/v1/models/'):]
         response = OAImodels.model_info_dict(model_name)
diff --git a/modules/shared.py b/modules/shared.py
index 43533a14..17e1a3ee 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -200,6 +200,7 @@ group.add_argument('--api-port', type=int, default=5000, help='The listening por
 group.add_argument('--api-key', type=str, default='', help='API authentication key.')
 group.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.')
 group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
+group.add_argument('--model-selection-mode', type=int, default=0, help='Model selection mode: bitwise flag. 1=Include dummy models, 2=Include local models, 4=Return only the currently loaded model if local models are included.')
 
 # Multimodal
 group = parser.add_argument_group('Multimodal')
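
A minimal sketch of how the --model-selection-mode bits from the patch combine; the MODE_* names below are illustrative only and not part of the diff:

MODE_DUMMY = 1        # bit 0: include the dummy models (gpt-3.5-turbo, text-embedding-ada-002)
MODE_LOCAL = 2        # bit 1: include the local models
MODE_LOADED_ONLY = 4  # bit 2: with MODE_LOCAL set, return only the currently loaded model

mode = MODE_DUMMY | MODE_LOCAL          # 3: dummy models plus every available local model
assert mode & 1 and mode & 2 and not mode & 4

mode = MODE_LOCAL | MODE_LOADED_ONLY    # 6: only the currently loaded local model
assert not mode & 1 and mode & 2 and mode & 4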