Add HTTPS support to APIs (openai and default) (#4270)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Jesus Alvarez 2023-10-12 21:31:13 -07:00 committed by GitHub
parent 43be1be598
commit ed66ca3cdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 11 deletions

View File

@ -1,4 +1,5 @@
import json import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
@ -14,6 +15,7 @@ from modules.text_generation import (
stop_everything_event stop_everything_event
) )
from modules.utils import get_available_models from modules.utils import get_available_models
from modules.logging_colors import logger
def get_model_info(): def get_model_info():
@ -199,11 +201,18 @@ class Handler(BaseHTTPRequestHandler):
def _run_server(port: int, share: bool = False, tunnel_id=str): def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler) server = ThreadingHTTPServer((address, port), Handler)
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
def on_start(public_url: str): def on_start(public_url: str):
print(f'Starting non-streaming server at public url {public_url}/api') logger.info(f'Starting non-streaming server at public url {public_url}/api')
if share: if share:
try: try:
@ -211,8 +220,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
except Exception: except Exception:
pass pass
else: else:
print( if ssl_verify:
f'Starting API at http://{address}:{port}/api') logger.info(f'Starting API at https://{address}:{port}/api')
else:
logger.info(f'Starting API at http://{address}:{port}/api')
server.serve_forever() server.serve_forever()

View File

@ -1,7 +1,10 @@
import asyncio import asyncio
import json import json
import ssl
from threading import Thread from threading import Thread
from websockets.server import serve
from extensions.api.util import ( from extensions.api.util import (
build_parameters, build_parameters,
try_start_cloudflared, try_start_cloudflared,
@ -10,7 +13,7 @@ from extensions.api.util import (
from modules import shared from modules import shared
from modules.chat import generate_chat_reply from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply from modules.text_generation import generate_reply
from websockets.server import serve from modules.logging_colors import logger
PATH = '/api/v1/stream' PATH = '/api/v1/stream'
@ -98,16 +101,28 @@ async def _handle_connection(websocket, path):
async def _run(host: str, port: int): async def _run(host: str, port: int):
async with serve(_handle_connection, host, port, ping_interval=None): ssl_certfile = shared.args.ssl_certfile
await asyncio.Future() # run forever ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
else:
context = None
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
await asyncio.Future() # Run the server forever
def _run_server(port: int, share: bool = False, tunnel_id=str): def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1' address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
def on_start(public_url: str): def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://') public_url = public_url.replace('https://', 'wss://')
print(f'Starting streaming server at public url {public_url}{PATH}') logger.info(f'Starting streaming server at public url {public_url}{PATH}')
if share: if share:
try: try:
@ -115,7 +130,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
except Exception as e: except Exception as e:
print(e) print(e)
else: else:
print(f'Starting streaming server at ws://{address}:{port}{PATH}') if ssl_verify:
logger.info(f'Starting streaming server at wss://{address}:{port}{PATH}')
else:
logger.info(f'Starting streaming server at ws://{address}:{port}{PATH}')
asyncio.run(_run(host=address, port=port)) asyncio.run(_run(host=address, port=port))

View File

@ -1,5 +1,6 @@
import json import json
import os import os
import ssl
import traceback import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
@ -322,6 +323,15 @@ def run_server():
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001))) port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port) server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler) server = ThreadingHTTPServer(server_addr, Handler)
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
if shared.args.share: if shared.args.share:
try: try:
from flask_cloudflared import _run_cloudflared from flask_cloudflared import _run_cloudflared
@ -330,8 +340,11 @@ def run_server():
except ImportError: except ImportError:
print('You should install flask_cloudflared manually') print('You should install flask_cloudflared manually')
else: else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') if ssl_verify:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
server.serve_forever() server.serve_forever()