mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Add /v1/internal/stop-generation to OpenAI API (#4498)
This commit is contained in:
parent
97c21e5667
commit
79b3f5a546
@ -18,6 +18,7 @@ from fastapi.requests import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import stop_everything_event
|
||||
from pydub import AudioSegment
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
@ -204,14 +205,7 @@ async def handle_moderations(request: Request):
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/api/v1/token-count")
|
||||
async def handle_token_count(request: Request):
|
||||
body = await request.json()
|
||||
response = token_count(body['prompt'])
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/api/v1/token/encode")
|
||||
@app.post("/v1/internal/encode")
|
||||
async def handle_token_encode(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
@ -219,7 +213,7 @@ async def handle_token_encode(request: Request):
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/api/v1/token/decode")
|
||||
@app.post("/v1/internal/decode")
|
||||
async def handle_token_decode(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
@ -227,6 +221,19 @@ async def handle_token_decode(request: Request):
|
||||
return JSONResponse(response, no_debug=True)
|
||||
|
||||
|
||||
@app.post("/v1/internal/token-count")
|
||||
async def handle_token_count(request: Request):
|
||||
body = await request.json()
|
||||
response = token_count(body['prompt'])
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/v1/internal/stop-generation")
|
||||
async def handle_stop_generation(request: Request):
|
||||
stop_everything_event()
|
||||
return JSONResponse(content="OK")
|
||||
|
||||
|
||||
def run_server():
|
||||
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||
|
Loading…
Reference in New Issue
Block a user