mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 09:19:23 +01:00
Add /v1/internal/stop-generation to OpenAI API (#4498)
This commit is contained in:
parent
97c21e5667
commit
79b3f5a546
@ -18,6 +18,7 @@ from fastapi.requests import Request
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.text_generation import stop_everything_event
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
from sse_starlette import EventSourceResponse
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
@ -204,14 +205,7 @@ async def handle_moderations(request: Request):
|
|||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/token-count")
|
@app.post("/v1/internal/encode")
|
||||||
async def handle_token_count(request: Request):
|
|
||||||
body = await request.json()
|
|
||||||
response = token_count(body['prompt'])
|
|
||||||
return JSONResponse(response)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/token/encode")
|
|
||||||
async def handle_token_encode(request: Request):
|
async def handle_token_encode(request: Request):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
encoding_format = body.get("encoding_format", "")
|
encoding_format = body.get("encoding_format", "")
|
||||||
@ -219,7 +213,7 @@ async def handle_token_encode(request: Request):
|
|||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/token/decode")
|
@app.post("/v1/internal/decode")
|
||||||
async def handle_token_decode(request: Request):
|
async def handle_token_decode(request: Request):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
encoding_format = body.get("encoding_format", "")
|
encoding_format = body.get("encoding_format", "")
|
||||||
@ -227,6 +221,19 @@ async def handle_token_decode(request: Request):
|
|||||||
return JSONResponse(response, no_debug=True)
|
return JSONResponse(response, no_debug=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/internal/token-count")
|
||||||
|
async def handle_token_count(request: Request):
|
||||||
|
body = await request.json()
|
||||||
|
response = token_count(body['prompt'])
|
||||||
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/internal/stop-generation")
|
||||||
|
async def handle_stop_generation(request: Request):
|
||||||
|
stop_everything_event()
|
||||||
|
return JSONResponse(content="OK")
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||||
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||||
|
Loading…
Reference in New Issue
Block a user