mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
3e7feb699c
* many openai updates * total reorg & cleanup. * fixups * missing import os for images * +moderations, custom_stopping_strings, more fixes * fix bugs in completion streaming * moderation fix (flagged) * updated moderation categories --------- Co-authored-by: Matthew Ashton <mashton-gitlab@zhero.org>
269 lines
9.7 KiB
Python
269 lines
9.7 KiB
Python
import json
|
|
import os
|
|
import traceback
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from threading import Thread
|
|
|
|
from modules import shared
|
|
|
|
from extensions.openai.tokens import token_count, token_encode, token_decode
|
|
import extensions.openai.models as OAImodels
|
|
import extensions.openai.edits as OAIedits
|
|
import extensions.openai.embeddings as OAIembeddings
|
|
import extensions.openai.images as OAIimages
|
|
import extensions.openai.moderations as OAImoderations
|
|
import extensions.openai.completions as OAIcompletions
|
|
from extensions.openai.errors import *
|
|
from extensions.openai.utils import debug_msg
|
|
from extensions.openai.defaults import (get_default_req_params, default, clamp)
|
|
|
|
|
|
# Extension settings. The listening port is taken from the OPENEDAI_PORT
# environment variable when it is set, otherwise it falls back to 5001.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler exposing OpenAI-style JSON endpoints.

    GET serves the model/engine listing and a stub billing endpoint; POST
    dispatches to the completions, edits, images, embeddings, moderations
    and token helper modules. All responses carry permissive CORS headers.
    """

    def send_access_control_headers(self):
        """Emit the CORS headers shared by every response."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def do_OPTIONS(self):
        """Answer an OPTIONS (CORS preflight) request with 200 and the CORS headers."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def start_sse(self):
        """Send the status line and headers that open a server-sent-events stream."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'text/event-stream')
        self.send_header('Cache-Control', 'no-cache')
        # self.send_header('Connection', 'keep-alive')
        self.end_headers()

    def send_sse(self, chunk: dict):
        """Write one JSON-encoded ``data:`` event to the open stream."""
        response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        debug_msg(response)
        self.wfile.write(response.encode('utf-8'))

    def end_sse(self):
        """Write the terminating ``data: [DONE]`` event."""
        self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))

    def return_json(self, ret: dict, code: int = 200, no_debug=False):
        """Send ``ret`` as a complete JSON response.

        Args:
            ret: JSON-serialisable payload.
            code: HTTP status code for the response.
            no_debug: when true, the encoded body is not passed to debug_msg.
        """
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

        response = json.dumps(ret)
        r_utf8 = response.encode('utf-8')
        self.wfile.write(r_utf8)
        if not no_debug:
            debug_msg(r_utf8)

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        """Send an error response shaped as ``{'error': {...}}``.

        Args:
            message: text placed in the error's ``message`` field.
            code: used both as the HTTP status and the error's ``code`` field.
            error_type: value of the error's ``type`` field.
            param: value of the error's ``param`` field.
            internal_message: extra detail printed on the server only; it is
                not included in the response body.
        """
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            print(internal_message)
            # error_resp['internal_message'] = internal_message

        self.return_json(error_resp, code)

    def openai_error_handler(func):
        """Decorator for ``do_*`` methods: turn raised exceptions into error responses.

        The extension's own exception types are reported with their own
        message/code/type; any other exception becomes a 500 whose message is
        ``repr(e)`` and whose traceback is printed on the server.

        NOTE(review): if the exception is raised after start_sse() has already
        sent the 200 status line, the error response (a second status line
        plus a JSON body) is written into the already-open stream.
        """
        def wrapper(self):
            try:
                func(self)
            except ServiceUnavailableError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except InvalidRequestError as e:
                self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
            except OpenAIError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except Exception as e:
                self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())

        return wrapper

    @openai_error_handler
    def do_GET(self):
        """Serve the model/engine endpoints and the stub billing endpoint.

        - ``/v1/engines`` and ``/v1/models``: model listing.
        - ``/v1/engines/<name>``: passes ``<name>`` to OAImodels.load_model.
        - any other ``/v1/models...`` path: OAImodels.model_info().
        - a path containing ``/billing/usage``: a fixed zero-usage payload.
        - anything else: 404.
        """
        debug_msg(self.requestline)
        debug_msg(self.headers)

        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
            # Decide by prefix, not substring: a model whose name merely
            # contains "engines" must not be routed to the legacy branch.
            is_legacy = self.path.startswith('/v1/engines')
            is_list = self.path in ['/v1/engines', '/v1/models']

            if is_list:
                resp = OAImodels.list_models(is_legacy)
            elif is_legacy:
                model_name = self.path[len('/v1/engines/'):]
                resp = OAImodels.load_model(model_name)
            else:
                # NOTE(review): the requested name (the part of the path after
                # '/v1/models/') is not passed on; confirm whether
                # model_info() is meant to receive it.
                resp = OAImodels.model_info()

            self.return_json(resp)

        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.return_json({"total_usage": 0}, no_debug=True)

        else:
            self.send_error(404)

    @openai_error_handler
    def do_POST(self):
        """Parse the JSON request body and dispatch on the request path.

        The body is read using the Content-Length header and decoded as UTF-8
        JSON. Branches are tested in order, so the substring checks
        (``/completions``, ``/generate``, ``/edits``, ...) take precedence over
        the exact-match token endpoints. Unknown paths get a 404.

        Raises:
            InvalidRequestError: when the embeddings or moderations input is
                missing or empty (converted to an error response by the
                decorator).
        """
        debug_msg(self.requestline)
        debug_msg(self.headers)

        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        debug_msg(body)

        if '/completions' in self.path or '/generate' in self.path:

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            is_legacy = '/generate' in self.path
            is_streaming = body.get('stream', False)
            is_chat = 'chat' in self.path

            if is_streaming:
                self.start_sse()

                if is_chat:
                    response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)

                for resp in response:
                    self.send_sse(resp)

                self.end_sse()

            else:
                if is_chat:
                    response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.completions(body, is_legacy=is_legacy)

                self.return_json(response)

        elif '/edits' in self.path:
            # deprecated

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            req_params = get_default_req_params()

            instruction = body['instruction']
            edit_input = body.get('input', '')
            temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
            top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)

            response = OAIedits.edits(instruction, edit_input, temperature, top_p)

            self.return_json(response)

        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            prompt = body['prompt']
            size = default(body, 'size', '1024x1024')
            response_format = default(body, 'response_format', 'url')  # or b64_json
            n = default(body, 'n', 1)  # ignore the batch limits of max 10

            response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)

            self.return_json(response, no_debug=True)

        elif '/embeddings' in self.path:
            encoding_format = body.get('encoding_format', '')

            # 'text' is accepted as a fallback key for 'input'.
            texts = body.get('input', body.get('text', ''))
            if not texts:
                # Keyword is `param`, matching the `e.param` attribute read in
                # openai_error_handler (the constructor itself lives in
                # extensions.openai.errors -- confirm there).
                raise InvalidRequestError("Missing required argument input", param='input')

            if isinstance(texts, str):
                texts = [texts]

            response = OAIembeddings.embeddings(texts, encoding_format)

            self.return_json(response, no_debug=True)

        elif '/moderations' in self.path:
            # .get so that an absent key is reported the same way as an empty
            # value instead of surfacing as a KeyError (HTTP 500).
            moderation_input = body.get('input')
            if not moderation_input:
                raise InvalidRequestError("Missing required argument input", param='input')

            response = OAImoderations.moderations(moderation_input)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
            response = token_count(body['prompt'])

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/encode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_encode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/decode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')

            response = token_decode(body['input'], encoding_format)

            self.return_json(response, no_debug=True)

        else:
            self.send_error(404)
|
|
|
|
|
|
def run_server():
    """Create the threaded HTTP server, announce its base URL and serve forever.

    Binds to all interfaces when ``shared.args.listen`` is set, otherwise to
    loopback only, on ``params['port']``. When ``shared.args.share`` is set a
    cloudflared tunnel is requested through ``flask_cloudflared`` and the
    public URL is printed; if that package cannot be imported a hint is
    printed instead. The server is started in every case. This call blocks.
    """
    host = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    port = params['port']
    httpd = ThreadingHTTPServer((host, port), Handler)

    if not shared.args.share:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{host}:{port}/v1')
    else:
        try:
            # Optional dependency, imported only when sharing is requested.
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')

    httpd.serve_forever()
|
|
|
|
|
|
def setup():
    """Start run_server on a background daemon thread and return immediately.

    Presumably the hook the host application calls when loading this
    extension -- confirm against the extension loader.
    """
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()
|