From ed66ca3cdf11e91401129be968473b5fc3ec6b84 Mon Sep 17 00:00:00 2001 From: Jesus Alvarez Date: Thu, 12 Oct 2023 21:31:13 -0700 Subject: [PATCH] Add HTTPS support to APIs (openai and default) (#4270) --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- extensions/api/blocking_api.py | 19 +++++++++++++++---- extensions/api/streaming_api.py | 28 +++++++++++++++++++++++----- extensions/openai/script.py | 17 +++++++++++++++-- 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index a91fd515..8d5850cf 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -1,4 +1,5 @@ import json +import ssl from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread @@ -14,6 +15,7 @@ from modules.text_generation import ( stop_everything_event ) from modules.utils import get_available_models +from modules.logging_colors import logger def get_model_info(): @@ -199,11 +201,18 @@ class Handler(BaseHTTPRequestHandler): def _run_server(port: int, share: bool = False, tunnel_id=str): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' - server = ThreadingHTTPServer((address, port), Handler) + ssl_certfile = shared.args.ssl_certfile + ssl_keyfile = shared.args.ssl_keyfile + ssl_verify = True if (ssl_keyfile and ssl_certfile) else False + if ssl_verify: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(ssl_certfile, ssl_keyfile) + server.socket = context.wrap_socket(server.socket, server_side=True) + def on_start(public_url: str): - print(f'Starting non-streaming server at public url {public_url}/api') + logger.info(f'Starting non-streaming server at public url {public_url}/api') if share: try: @@ -211,8 +220,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str): except Exception: pass else: - print( - f'Starting API at http://{address}:{port}/api') + if ssl_verify: + logger.info(f'Starting API at https://{address}:{port}/api') + else: + logger.info(f'Starting API at http://{address}:{port}/api') server.serve_forever() diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 9175eeb0..71113c2e 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -1,7 +1,10 @@ import asyncio import json +import ssl from threading import Thread +from websockets.server import serve + from extensions.api.util import ( build_parameters, try_start_cloudflared, @@ -10,7 +13,7 @@ from extensions.api.util import ( from modules import shared from modules.chat import generate_chat_reply from modules.text_generation import generate_reply -from websockets.server import serve +from modules.logging_colors import logger PATH = '/api/v1/stream' @@ -98,16 +101,28 @@ async def _handle_connection(websocket, path): async def _run(host: str, port: int): - async with serve(_handle_connection, host, port, ping_interval=None): - await asyncio.Future() # run forever + ssl_certfile = shared.args.ssl_certfile + ssl_keyfile = shared.args.ssl_keyfile + ssl_verify = True if (ssl_keyfile and ssl_certfile) else False + if ssl_verify: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(ssl_certfile, ssl_keyfile) + else: + context = None + + async with serve(_handle_connection, host, port, ping_interval=None, ssl=context): + await asyncio.Future() # Run the server forever def _run_server(port: int, share: bool = False, tunnel_id=str): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' + ssl_certfile = shared.args.ssl_certfile + ssl_keyfile = shared.args.ssl_keyfile + ssl_verify = True if (ssl_keyfile and ssl_certfile) else False def on_start(public_url: str): public_url = public_url.replace('https://', 'wss://') - print(f'Starting streaming server at public url {public_url}{PATH}') + logger.info(f'Starting streaming server at public url {public_url}{PATH}') if share: try: @@ -115,7 +130,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str): except Exception as e: print(e) else: - print(f'Starting streaming server at ws://{address}:{port}{PATH}') + if ssl_verify: + logger.info(f'Starting streaming server at wss://{address}:{port}{PATH}') + else: + logger.info(f'Starting streaming server at ws://{address}:{port}{PATH}') asyncio.run(_run(host=address, port=port)) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index b44fc535..097689bb 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,5 +1,6 @@ import json import os +import ssl import traceback from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread @@ -322,6 +323,15 @@ def run_server(): port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001))) server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port) server = ThreadingHTTPServer(server_addr, Handler) + + ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile) + ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile) + ssl_verify=True if (ssl_keyfile and ssl_certfile) else False + if ssl_verify: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(ssl_certfile, ssl_keyfile) + server.socket = context.wrap_socket(server.socket, server_side=True) + if shared.args.share: try: from flask_cloudflared import _run_cloudflared @@ -330,8 +340,11 @@ def run_server(): except ImportError: print('You should install flask_cloudflared manually') else: - print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') - + if ssl_verify: + print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1') + else: + print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') + server.serve_forever()