mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Add HTTPS support to APIs (openai and default) (#4270)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
43be1be598
commit
ed66ca3cdf
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@ -14,6 +15,7 @@ from modules.text_generation import (
|
|||||||
stop_everything_event
|
stop_everything_event
|
||||||
)
|
)
|
||||||
from modules.utils import get_available_models
|
from modules.utils import get_available_models
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
|
||||||
def get_model_info():
|
def get_model_info():
|
||||||
@ -199,11 +201,18 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
||||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||||
|
|
||||||
server = ThreadingHTTPServer((address, port), Handler)
|
server = ThreadingHTTPServer((address, port), Handler)
|
||||||
|
|
||||||
|
ssl_certfile = shared.args.ssl_certfile
|
||||||
|
ssl_keyfile = shared.args.ssl_keyfile
|
||||||
|
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||||
|
if ssl_verify:
|
||||||
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||||
|
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||||
|
|
||||||
def on_start(public_url: str):
|
def on_start(public_url: str):
|
||||||
print(f'Starting non-streaming server at public url {public_url}/api')
|
logger.info(f'Starting non-streaming server at public url {public_url}/api')
|
||||||
|
|
||||||
if share:
|
if share:
|
||||||
try:
|
try:
|
||||||
@ -211,8 +220,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print(
|
if ssl_verify:
|
||||||
f'Starting API at http://{address}:{port}/api')
|
logger.info(f'Starting API at https://{address}:{port}/api')
|
||||||
|
else:
|
||||||
|
logger.info(f'Starting API at http://{address}:{port}/api')
|
||||||
|
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
from websockets.server import serve
|
||||||
|
|
||||||
from extensions.api.util import (
|
from extensions.api.util import (
|
||||||
build_parameters,
|
build_parameters,
|
||||||
try_start_cloudflared,
|
try_start_cloudflared,
|
||||||
@ -10,7 +13,7 @@ from extensions.api.util import (
|
|||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.chat import generate_chat_reply
|
from modules.chat import generate_chat_reply
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
from websockets.server import serve
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
PATH = '/api/v1/stream'
|
PATH = '/api/v1/stream'
|
||||||
|
|
||||||
@ -98,16 +101,28 @@ async def _handle_connection(websocket, path):
|
|||||||
|
|
||||||
|
|
||||||
async def _run(host: str, port: int):
|
async def _run(host: str, port: int):
|
||||||
async with serve(_handle_connection, host, port, ping_interval=None):
|
ssl_certfile = shared.args.ssl_certfile
|
||||||
await asyncio.Future() # run forever
|
ssl_keyfile = shared.args.ssl_keyfile
|
||||||
|
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||||
|
if ssl_verify:
|
||||||
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||||
|
else:
|
||||||
|
context = None
|
||||||
|
|
||||||
|
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
|
||||||
|
await asyncio.Future() # Run the server forever
|
||||||
|
|
||||||
|
|
||||||
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
||||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||||
|
ssl_certfile = shared.args.ssl_certfile
|
||||||
|
ssl_keyfile = shared.args.ssl_keyfile
|
||||||
|
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||||
|
|
||||||
def on_start(public_url: str):
|
def on_start(public_url: str):
|
||||||
public_url = public_url.replace('https://', 'wss://')
|
public_url = public_url.replace('https://', 'wss://')
|
||||||
print(f'Starting streaming server at public url {public_url}{PATH}')
|
logger.info(f'Starting streaming server at public url {public_url}{PATH}')
|
||||||
|
|
||||||
if share:
|
if share:
|
||||||
try:
|
try:
|
||||||
@ -115,7 +130,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
else:
|
else:
|
||||||
print(f'Starting streaming server at ws://{address}:{port}{PATH}')
|
if ssl_verify:
|
||||||
|
logger.info(f'Starting streaming server at wss://{address}:{port}{PATH}')
|
||||||
|
else:
|
||||||
|
logger.info(f'Starting streaming server at ws://{address}:{port}{PATH}')
|
||||||
|
|
||||||
asyncio.run(_run(host=address, port=port))
|
asyncio.run(_run(host=address, port=port))
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import ssl
|
||||||
import traceback
|
import traceback
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -322,6 +323,15 @@ def run_server():
|
|||||||
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
|
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
|
||||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
|
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
|
||||||
server = ThreadingHTTPServer(server_addr, Handler)
|
server = ThreadingHTTPServer(server_addr, Handler)
|
||||||
|
|
||||||
|
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||||
|
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||||
|
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
|
||||||
|
if ssl_verify:
|
||||||
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||||
|
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||||
|
|
||||||
if shared.args.share:
|
if shared.args.share:
|
||||||
try:
|
try:
|
||||||
from flask_cloudflared import _run_cloudflared
|
from flask_cloudflared import _run_cloudflared
|
||||||
@ -329,6 +339,9 @@ def run_server():
|
|||||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('You should install flask_cloudflared manually')
|
print('You should install flask_cloudflared manually')
|
||||||
|
else:
|
||||||
|
if ssl_verify:
|
||||||
|
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
|
||||||
else:
|
else:
|
||||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user