Add /v1/internal/stop-generation to OpenAI API (#4498)

This commit is contained in:
oobabooga 2023-11-07 00:10:42 -03:00 committed by GitHub
parent 97c21e5667
commit 79b3f5a546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,6 +18,7 @@ from fastapi.requests import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from modules import shared from modules import shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from pydub import AudioSegment from pydub import AudioSegment
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
@ -204,14 +205,7 @@ async def handle_moderations(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.post("/api/v1/token-count") @app.post("/v1/internal/encode")
async def handle_token_count(request: Request):
body = await request.json()
response = token_count(body['prompt'])
return JSONResponse(response)
@app.post("/api/v1/token/encode")
async def handle_token_encode(request: Request): async def handle_token_encode(request: Request):
body = await request.json() body = await request.json()
encoding_format = body.get("encoding_format", "") encoding_format = body.get("encoding_format", "")
@ -219,7 +213,7 @@ async def handle_token_encode(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.post("/api/v1/token/decode") @app.post("/v1/internal/decode")
async def handle_token_decode(request: Request): async def handle_token_decode(request: Request):
body = await request.json() body = await request.json()
encoding_format = body.get("encoding_format", "") encoding_format = body.get("encoding_format", "")
@ -227,6 +221,19 @@ async def handle_token_decode(request: Request):
return JSONResponse(response, no_debug=True) return JSONResponse(response, no_debug=True)
@app.post("/v1/internal/token-count")
async def handle_token_count(request: Request):
body = await request.json()
response = token_count(body['prompt'])
return JSONResponse(response)
@app.post("/v1/internal/stop-generation")
async def handle_stop_generation(request: Request):
stop_everything_event()
return JSONResponse(content="OK")
def run_server(): def run_server():
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1' server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port)) port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))