text-generation-webui/extensions/openai/script.py

402 lines
14 KiB
Python
Raw Normal View History

import asyncio
2023-05-03 04:05:38 +02:00
import json
import logging
2023-05-03 04:05:38 +02:00
import os
import traceback
from collections import deque
2023-05-03 03:49:53 +02:00
from threading import Thread
2023-05-10 03:49:39 +02:00
2023-11-17 03:03:06 +01:00
import speech_recognition as sr
import uvicorn
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
2023-09-16 05:11:16 +02:00
import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
2023-09-16 05:11:16 +02:00
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
2023-09-16 05:11:16 +02:00
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
2023-09-16 05:11:16 +02:00
from modules import shared
from modules.logging_colors import logger
2023-11-16 00:48:33 +01:00
from modules.models import unload_model
from modules.text_generation import stop_everything_event
from .typing import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatPromptResponse,
CompletionRequest,
CompletionResponse,
DecodeRequest,
DecodeResponse,
EmbeddingsRequest,
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
LoadLorasRequest,
LoadModelRequest,
LogitsRequest,
LogitsResponse,
LoraListResponse,
2023-11-08 03:59:02 +01:00
ModelInfoResponse,
ModelListResponse,
TokenCountResponse,
to_dict
)
2023-05-03 03:49:53 +02:00
# Default settings for this extension.
params = dict(
    embedding_device='cpu',  # presumably consumed by the embeddings module — verify
    embedding_model='sentence-transformers/all-mpnet-base-v2',
    sd_webui_url='',  # fallback used by the image route when the SD_WEBUI_URL env var is unset
    debug=0,
)

# Streaming endpoints acquire this for the whole stream, so at most one
# streamed response is produced at a time; other streams wait their turn.
streaming_semaphore = asyncio.Semaphore(1)
def verify_api_key(authorization: str = Header(None)) -> None:
    """FastAPI dependency guarding the regular API routes.

    If no API key is configured (``shared.args.api_key`` is falsy), every
    request is allowed.

    Raises:
        HTTPException: 401 when an API key is configured and the
            ``Authorization`` header is not exactly ``Bearer <api_key>``.
    """
    import hmac

    expected_api_key = shared.args.api_key
    if not expected_api_key:
        return  # no key configured: the API is open

    # Compare in constant time so response timing does not reveal how much
    # of the key a caller guessed correctly.  Bytes are used because
    # compare_digest only accepts ASCII-only str.
    supplied = (authorization or "").encode("utf-8")
    expected = f"Bearer {expected_api_key}".encode("utf-8")
    if authorization is None or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
def verify_admin_key(authorization: str = Header(None)) -> None:
    """FastAPI dependency guarding the model/LoRA management routes.

    If no admin key is configured (``shared.args.admin_key`` is falsy),
    every request is allowed.

    Raises:
        HTTPException: 401 when an admin key is configured and the
            ``Authorization`` header is not exactly ``Bearer <admin_key>``.
    """
    import hmac

    expected_api_key = shared.args.admin_key
    if not expected_api_key:
        return  # no admin key configured: management routes are open

    # Constant-time comparison; bytes because compare_digest rejects non-ASCII str.
    supplied = (authorization or "").encode("utf-8")
    expected = f"Bearer {expected_api_key}".encode("utf-8")
    if authorization is None or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
app = FastAPI()

# Dependency lists attached to routes: ordinary endpoints use check_key,
# model/LoRA management endpoints use check_admin_key.
check_key = [Depends(verify_api_key)]
check_admin_key = [Depends(verify_admin_key)]

# CORS: accept any origin, method and header (credentials allowed).
cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_settings)
@app.options("/", dependencies=check_key)
async def options_route():
    """Answer an OPTIONS request on the root path with the JSON string "OK"."""
    ok = "OK"
    return JSONResponse(content=ok)
@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: CompletionRequest):
    """OpenAI-style text completion endpoint.

    When ``request_data.stream`` is set, chunks are sent as server-sent
    events; otherwise the whole completion is returned as one JSON body.
    Paths containing "/generate" are treated as legacy requests.
    """
    is_legacy = "/generate" in request.url.path

    if not request_data.stream:
        result = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
        return JSONResponse(result)

    async def generator():
        # Hold the semaphore for the whole stream so streams run one at a time.
        async with streaming_semaphore:
            chunks = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
            for chunk in chunks:
                # Stop producing output as soon as the client has gone away.
                if await request.is_disconnected():
                    break
                yield {"data": json.dumps(chunk)}

    return EventSourceResponse(generator())  # SSE streaming
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
    """OpenAI-style chat completion endpoint.

    If the request names a model other than the one currently loaded, that
    model is loaded first.  If the request omits the model, the current model
    is used.  When ``request_data.stream`` is set, chunks are sent as
    server-sent events; otherwise one JSON body is returned.

    Raises:
        HTTPException: 400 if loading the requested model fails.
    """
    requested_model = request_data.model
    current_model = OAImodels.get_current_model_info()["model_name"]

    # NOTE(review): this lets holders of the regular API key switch models,
    # whereas /v1/internal/model/load requires the admin key — confirm intended.
    if requested_model and requested_model != current_model:
        try:
            OAImodels._load_model({"model_name": requested_model})
        except Exception as exc:
            traceback.print_exc()
            # Must be raised (not returned) for FastAPI to send a 400 response.
            raise HTTPException(status_code=400, detail="Failed to load the model.") from exc

    is_legacy = "/generate" in request.url.path

    if request_data.stream:
        async def generator():
            # Hold the semaphore for the whole stream so streams run one at a time.
            async with streaming_semaphore:
                response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
                for resp in response:
                    # Stop producing output as soon as the client has gone away.
                    if await request.is_disconnected():
                        break
                    yield {"data": json.dumps(resp)}

        return EventSourceResponse(generator())  # SSE streaming

    response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
    return JSONResponse(response)
@app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}", dependencies=check_key)
async def handle_models(request: Request):
    """List the advertised models, or describe a single model.

    ``/v1/models`` returns the (dummy) model list.  ``/v1/models/<name>``
    returns info for ``<name>``, taken as everything after the
    ``/v1/models/`` prefix of the path.
    """
    prefix = '/v1/models/'
    full_path = request.url.path
    bare_path = full_path.split('?')[0].split('#')[0]

    if bare_path != '/v1/models':
        return JSONResponse(OAImodels.model_info_dict(full_path[len(prefix):]))
    return JSONResponse(OAImodels.list_dummy_models())
@app.get('/v1/billing/usage', dependencies=check_key)
def handle_billing_usage():
    """Stub billing endpoint that always reports zero total usage.

    Query parameters are accepted but ignored, e.g.
    ``/v1/billing/usage?start_date=2023-05-01&end_date=2023-05-31``.
    """
    usage = {"total_usage": 0}
    return JSONResponse(content=usage)
@app.post('/v1/audio/transcriptions', dependencies=check_key)
async def handle_audio_transcription(request: Request):
    """Transcribe an uploaded audio file with Whisper via speech_recognition.

    Multipart form fields:
        file: the audio upload (decoded by pydub).
        language: optional Whisper language hint (empty means no hint).
        model: optional Whisper model name; defaults to "tiny".

    Returns:
        JSON ``{"text": ...}``.  Recognition failures are reported inside
        ``text`` rather than as an HTTP error.

    Raises:
        HTTPException: 400 if the ``file`` field is missing or not an upload.
    """
    import io

    r = sr.Recognizer()
    form = await request.form()

    upload = form.get("file")
    if upload is None or isinstance(upload, str):
        raise HTTPException(status_code=400, detail="Missing required form field: file")
    audio_bytes = await upload.read()

    # pydub needs a path or file-like object; raw bytes must be wrapped.
    segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
    # sr.AudioData represents mono PCM, so downmix multi-channel input first.
    segment = segment.set_channels(1)
    audio_data = sr.AudioData(segment.raw_data, segment.frame_rate, segment.sample_width)

    # Starlette's FormData exposes .get(); .getvalue() (the cgi.FieldStorage
    # API) does not exist on it.
    whisper_language = form.get('language') or None
    whisper_model = form.get('model') or 'tiny'  # default to the smallest Whisper model

    transcription = {"text": ""}
    try:
        transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
    except sr.UnknownValueError:
        print("Whisper could not understand audio")
        transcription["text"] = "Whisper could not understand audio UnknownValueError"
    except sr.RequestError as e:
        print("Could not request results from Whisper", e)
        transcription["text"] = "Whisper could not understand audio RequestError"

    return JSONResponse(content=transcription)
@app.post('/v1/images/generations', dependencies=check_key)
async def handle_image_generation(request: Request):
    """Generate images via a Stable Diffusion web UI backend.

    The backend URL comes from the SD_WEBUI_URL environment variable, or
    falls back to ``params['sd_webui_url']``.  The JSON body must contain
    ``prompt``; ``size``, ``response_format`` and ``n`` are optional.

    Raises:
        ServiceUnavailableError: when no Stable Diffusion URL is configured.
    """
    sd_url = os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))
    if not sd_url:
        raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")

    body = await request.json()
    response = await OAIimages.generations(
        prompt=body['prompt'],
        size=body.get('size', '1024x1024'),
        response_format=body.get('response_format', 'url'),  # or b64_json
        n=body.get('n', 1),  # ignore the batch limits of max 10
    )
    return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
    """Compute embeddings for a string or a list of inputs.

    A single string is wrapped in a one-element list before being passed
    to the embeddings backend together with ``encoding_format``.

    Raises:
        HTTPException: 400 when ``input`` is empty or missing.
    """
    texts = request_data.input
    if not texts:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    batch = [texts] if type(texts) is str else texts
    return JSONResponse(OAIembeddings.embeddings(batch, request_data.encoding_format))
@app.post("/v1/moderations", dependencies=check_key)
async def handle_moderations(request: Request):
    """Run the moderation check on the ``input`` field of the JSON body.

    Raises:
        HTTPException: 400 when ``input`` is missing or empty.
    """
    body = await request.json()
    # Use .get so a missing key gets the intended 400 instead of a KeyError (500).
    input_text = body.get("input") if isinstance(body, dict) else None
    if not input_text:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    response = OAImoderations.moderations(input_text)
    return JSONResponse(response)
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
async def handle_token_encode(request_data: EncodeRequest):
    """Encode ``request_data.text`` into tokens using ``token_encode``."""
    encoded = token_encode(request_data.text)
    return JSONResponse(encoded)
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
async def handle_token_decode(request_data: DecodeRequest):
    """Decode ``request_data.tokens`` back into text using ``token_decode``."""
    decoded = token_decode(request_data.tokens)
    return JSONResponse(decoded)
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
async def handle_token_count(request_data: EncodeRequest):
    """Return the token count of ``request_data.text`` via ``token_count``."""
    counted = token_count(request_data.text)
    return JSONResponse(counted)
@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    """Return next-token probabilities for a prompt.

    Delegates to ``OAIlogits._get_next_logits``.  Per the original
    documentation, the result is a dict of the top 50 most likely tokens
    mapped to their probabilities.
    """
    payload = to_dict(request_data)
    return JSONResponse(OAIlogits._get_next_logits(payload))
@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
    """Build the chat prompt for a request without generating a reply.

    Runs ``chat_completions_common`` with ``prompt_only=True`` and returns
    the last item it yields.
    """
    is_legacy = "/generate" in request.url.path
    results = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
    # Exhaust the generator and keep only its final item.
    last_item = deque(results, maxlen=1).pop()
    return JSONResponse(last_item)
@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
    """Signal any in-progress generation to stop, then answer "OK"."""
    stop_everything_event()
    ok = "OK"
    return JSONResponse(content=ok)
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info():
    """Return info about the currently loaded model."""
    return JSONResponse(content=OAImodels.get_current_model_info())
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
async def handle_list_models():
    """Return the models available to load (admin key required)."""
    return JSONResponse(content=OAImodels.list_models())
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest):
    '''
    This endpoint is experimental and may change in the future.

    The "args" parameter can be used to modify flags like "--load-in-4bit"
    or "--n-gpu-layers" before loading a model. Example:

    ```
    "args": {
      "load_in_4bit": true,
      "n_gpu_layers": 12
    }
    ```

    Note that those settings will remain after loading the model. So you
    may need to change them back to load a second model.

    The "settings" parameter is also a dict but with keys for the
    shared.settings object. It can be used to modify the default instruction
    template like this:

    ```
    "settings": {
      "instruction_template": "Alpaca"
    }
    ```

    Responds with "OK" on success. If loading fails, the traceback is
    printed to stderr and an HTTP 400 error is raised.
    '''
    try:
        OAImodels._load_model(to_dict(request_data))
    except Exception as exc:
        traceback.print_exc()
        # Must be raised (not returned) for FastAPI to send a 400 response.
        raise HTTPException(status_code=400, detail="Failed to load the model.") from exc
    return JSONResponse(content="OK")
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
    """Unload the currently loaded model (admin key required).

    Responds with "OK", consistent with the other model/LoRA management
    endpoints; the body was previously ``null``.
    """
    unload_model()
    return JSONResponse(content="OK")
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
async def handle_list_loras():
    """Return the LoRAs available to apply (admin key required)."""
    return JSONResponse(content=OAImodels.list_loras())
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
    """Apply the LoRAs named in ``request_data.lora_names`` (admin key required).

    Responds with "OK" on success.

    Raises:
        HTTPException: 400 if applying the LoRA(s) fails; the traceback is
            printed to stderr first.
    """
    try:
        OAImodels.load_loras(request_data.lora_names)
    except Exception as exc:
        traceback.print_exc()
        # Must be raised (not returned) for FastAPI to send a 400 response.
        raise HTTPException(status_code=400, detail="Failed to apply the LoRA(s).") from exc
    return JSONResponse(content="OK")
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
    """Remove every applied LoRA (admin key required) and answer "OK"."""
    OAImodels.unload_all_loras()
    ok = "OK"
    return JSONResponse(content=ok)
def run_server():
    """Configure and run the uvicorn server for ``app``; blocks while serving.

    Bind address, port and TLS files come from ``shared.args``.  Port and
    TLS paths can be overridden by the OPENEDAI_PORT, OPENEDAI_CERT_PATH and
    OPENEDAI_KEY_PATH environment variables.

    When ``shared.args.public_api`` is set, ``_start_cloudflared`` is started
    (presumably a Cloudflare tunnel) and its public URL is logged.
    Otherwise the local URL is logged.

    If only an API key is configured, it is also used as the admin key.
    """
    args = shared.args
    env = os.environ

    server_addr = '0.0.0.0' if args.listen else '127.0.0.1'
    port = int(env.get('OPENEDAI_PORT', args.api_port))
    ssl_certfile = env.get('OPENEDAI_CERT_PATH', args.ssl_certfile)
    ssl_keyfile = env.get('OPENEDAI_KEY_PATH', args.ssl_keyfile)

    if not args.public_api:
        scheme = 'https' if ssl_keyfile and ssl_certfile else 'http'
        logger.info(f'OpenAI-compatible API URL:\n\n{scheme}://{server_addr}:{port}\n')
    else:
        def on_start(public_url: str):
            logger.info(f'OpenAI-compatible API URL:\n\n{public_url}\n')

        _start_cloudflared(port, args.public_api_id, max_attempts=3, on_start=on_start)

    if args.api_key:
        # Fall back to the API key for admin routes when no admin key was given.
        if not args.admin_key:
            args.admin_key = args.api_key
        logger.info(f'OpenAI API key:\n\n{args.api_key}\n')

    if args.admin_key and args.admin_key != args.api_key:
        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{args.admin_key}\n')

    # Stop uvicorn.error records propagating to ancestor loggers
    # (presumably to avoid duplicated log lines).
    logging.getLogger("uvicorn.error").propagate = False
    uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
def setup():
    """Extension entry point that starts the API server.

    With ``shared.args.nowebui`` set, the server runs in the calling thread
    (blocking).  Otherwise it runs in a background daemon thread.
    """
    if not shared.args.nowebui:
        Thread(target=run_server, daemon=True).start()
        return
    run_server()