From 79b3f5a5469a8afa3796841a5eb1d54f2a6aad58 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 7 Nov 2023 00:10:42 -0300
Subject: [PATCH] Add /v1/internal/stop-generation to OpenAI API (#4498)

---
 extensions/openai/script.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index ec145e05..71c1ddf2 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -18,6 +18,7 @@ from fastapi.requests import Request
 from fastapi.responses import JSONResponse
 from modules import shared
 from modules.logging_colors import logger
+from modules.text_generation import stop_everything_event
 from pydub import AudioSegment
 from sse_starlette import EventSourceResponse
 
@@ -204,14 +205,7 @@ async def handle_moderations(request: Request):
     return JSONResponse(response)
 
 
-@app.post("/api/v1/token-count")
-async def handle_token_count(request: Request):
-    body = await request.json()
-    response = token_count(body['prompt'])
-    return JSONResponse(response)
-
-
-@app.post("/api/v1/token/encode")
+@app.post("/v1/internal/encode")
 async def handle_token_encode(request: Request):
     body = await request.json()
     encoding_format = body.get("encoding_format", "")
@@ -219,7 +213,7 @@ async def handle_token_encode(request: Request):
     return JSONResponse(response)
 
 
-@app.post("/api/v1/token/decode")
+@app.post("/v1/internal/decode")
 async def handle_token_decode(request: Request):
     body = await request.json()
     encoding_format = body.get("encoding_format", "")
@@ -227,6 +221,19 @@ async def handle_token_decode(request: Request):
     return JSONResponse(response, no_debug=True)
 
 
+@app.post("/v1/internal/token-count")
+async def handle_token_count(request: Request):
+    body = await request.json()
+    response = token_count(body['prompt'])
+    return JSONResponse(response)
+
+
+@app.post("/v1/internal/stop-generation")
+async def handle_stop_generation(request: Request):
+    stop_everything_event()
+    return JSONResponse(content="OK")
+
+
 def run_server():
     server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
     port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))