From 8f4f4daf8bb7f17bff8e2813053f1aca45e85d8a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 18 Nov 2023 22:33:27 -0300 Subject: [PATCH] Add --admin-key flag for API (#4649) --- README.md | 1 + extensions/openai/script.py | 50 ++++++++++++++++++++++++------------- modules/shared.py | 1 + 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 3ffaaf10..8c2679cf 100644 --- a/README.md +++ b/README.md @@ -413,6 +413,7 @@ Optionally, you can use the following command-line flags: | `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. | | `--api-port API_PORT` | The listening port for the API. | | `--api-key API_KEY` | API authentication key. | +| `--admin-key ADMIN_KEY` | API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key. | #### Multimodal diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 2128444e..43d4b261 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -60,7 +60,15 @@ def verify_api_key(authorization: str = Header(None)) -> None: raise HTTPException(status_code=401, detail="Unauthorized") -app = FastAPI(dependencies=[Depends(verify_api_key)]) +def verify_admin_key(authorization: str = Header(None)) -> None: + expected_api_key = shared.args.admin_key + if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"): + raise HTTPException(status_code=401, detail="Unauthorized") + + +app = FastAPI() +check_key = [Depends(verify_api_key)] +check_admin_key = [Depends(verify_admin_key)] # Configure CORS settings to allow all origins, methods, and headers app.add_middleware( @@ -72,12 +80,12 @@ app.add_middleware( ) -@app.options("/") +@app.options("/", dependencies=check_key) async def options_route(): return JSONResponse(content="OK") -@app.post('/v1/completions', response_model=CompletionResponse) +@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key) async def openai_completions(request: Request, request_data: CompletionRequest): path = request.url.path is_legacy = "/generate" in path @@ -100,7 +108,7 @@ async def openai_completions(request: Request, request_data: CompletionRequest): return JSONResponse(response) -@app.post('/v1/chat/completions', response_model=ChatCompletionResponse) +@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key) async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest): path = request.url.path is_legacy = "/generate" in path @@ -123,8 +131,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion return JSONResponse(response) -@app.get("/v1/models") -@app.get("/v1/models/{model}") +@app.get("/v1/models", dependencies=check_key) +@app.get("/v1/models/{model}", dependencies=check_key) async def handle_models(request: Request): path = request.url.path is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models' @@ -138,7 +146,7 @@ async def handle_models(request: Request): return JSONResponse(response) -@app.get('/v1/billing/usage') +@app.get('/v1/billing/usage', dependencies=check_key) def handle_billing_usage(): ''' Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 @@ -146,7 +154,7 @@ def handle_billing_usage(): return JSONResponse(content={"total_usage": 0}) -@app.post('/v1/audio/transcriptions') +@app.post('/v1/audio/transcriptions', dependencies=check_key) async def handle_audio_transcription(request: Request): r = sr.Recognizer() @@ -176,7 +184,7 @@ async def handle_audio_transcription(request: Request): return JSONResponse(content=transcription) -@app.post('/v1/images/generations') +@app.post('/v1/images/generations', dependencies=check_key) async def handle_image_generation(request: Request): if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): @@ -192,7 +200,7 @@ async def handle_image_generation(request: Request): return JSONResponse(response) -@app.post("/v1/embeddings", response_model=EmbeddingsResponse) +@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key) async def handle_embeddings(request: Request, request_data: EmbeddingsRequest): input = request_data.input if not input: @@ -205,7 +213,7 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest): return JSONResponse(response) -@app.post("/v1/moderations") +@app.post("/v1/moderations", dependencies=check_key) async def handle_moderations(request: Request): body = await request.json() input = body["input"] @@ -216,37 +224,37 @@ async def handle_moderations(request: Request): return JSONResponse(response) -@app.post("/v1/internal/encode", response_model=EncodeResponse) +@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key) async def handle_token_encode(request_data: EncodeRequest): response = token_encode(request_data.text) return JSONResponse(response) -@app.post("/v1/internal/decode", response_model=DecodeResponse) +@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key) async def handle_token_decode(request_data: DecodeRequest): response = token_decode(request_data.tokens) return JSONResponse(response) -@app.post("/v1/internal/token-count", response_model=TokenCountResponse) +@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key) async def handle_token_count(request_data: EncodeRequest): response = token_count(request_data.text) return JSONResponse(response) -@app.post("/v1/internal/stop-generation") +@app.post("/v1/internal/stop-generation", dependencies=check_key) async def handle_stop_generation(request: Request): stop_everything_event() return JSONResponse(content="OK") -@app.get("/v1/internal/model/info", response_model=ModelInfoResponse) +@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key) async def handle_model_info(): payload = OAImodels.get_current_model_info() return JSONResponse(content=payload) -@app.post("/v1/internal/model/load") +@app.post("/v1/internal/model/load", dependencies=check_admin_key) async def handle_load_model(request_data: LoadModelRequest): ''' This endpoint is experimental and may change in the future. @@ -283,7 +291,7 @@ async def handle_load_model(request_data: LoadModelRequest): return HTTPException(status_code=400, detail="Failed to load the model.") -@app.post("/v1/internal/model/unload") +@app.post("/v1/internal/model/unload", dependencies=check_admin_key) async def handle_unload_model(): unload_model() return JSONResponse(content="OK") @@ -308,8 +316,14 @@ def run_server(): logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n') if shared.args.api_key: + if not shared.args.admin_key: + shared.args.admin_key = shared.args.api_key + logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n') + if shared.args.admin_key: + logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n') + uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) diff --git a/modules/shared.py b/modules/shared.py index 54e72a6c..b139a2cf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -170,6 +170,7 @@ parser.add_argument('--public-api', action='store_true', help='Create a public U parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') parser.add_argument('--api-key', type=str, default='', help='API authentication key.') +parser.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.') # Multimodal parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')