diff --git a/extensions/openai/models.py b/extensions/openai/models.py index 83e550f8..b213c1f8 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -7,6 +7,13 @@ from modules.models_settings import get_model_metadata, update_model_parameters from modules.utils import get_available_models +def get_current_model_info(): + return { + 'model_name': shared.model_name, + 'lora_names': shared.lora_names + } + + def get_current_model_list() -> list: return [shared.model_name] # The real chat/completions model, maybe "None" diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 71c1ddf2..72c2776b 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -27,6 +27,7 @@ from .typing import ( ChatCompletionResponse, CompletionRequest, CompletionResponse, + ModelInfoResponse, to_dict ) @@ -234,6 +235,12 @@ async def handle_stop_generation(request: Request): return JSONResponse(content="OK") +@app.get("/v1/internal/model-info", response_model=ModelInfoResponse) +async def handle_model_info(): + payload = OAImodels.get_current_model_info() + return JSONResponse(content=payload) + + def run_server(): server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1' port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port)) diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 31fb03db..4e0211b2 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -121,6 +121,11 @@ class ChatCompletionResponse(BaseModel): usage: dict +class ModelInfoResponse(BaseModel): + model_name: str + lora_names: List[str] + + def to_json(obj): return json.dumps(obj.__dict__, indent=4)