diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 03d99e8d..f23caf9b 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -353,23 +353,38 @@ async def handle_unload_loras():
 
 
 def run_server():
-    server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
+    # Parse configuration
     port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
-
     ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
     ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
 
+    # In the server configuration:
+    server_addrs = []
+    if os.environ.get('OPENEDAI_ENABLE_IPV6', shared.args.api_enable_ipv6):
+        server_addrs.append('[::]' if shared.args.listen else '[::1]')
+    if not os.environ.get('OPENEDAI_DISABLE_IPV4', shared.args.api_disable_ipv4):
+        server_addrs.append('0.0.0.0' if shared.args.listen else '127.0.0.1')
+
+    if not server_addrs:
+        raise Exception('you MUST enable IPv6 or IPv4 for the API to work')
+
+    # Log server information
     if shared.args.public_api:
-        def on_start(public_url: str):
-            logger.info(f'OpenAI-compatible API URL:\n\n{public_url}\n')
-
-        _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
+        _start_cloudflared(
+            port,
+            shared.args.public_api_id,
+            max_attempts=3,
+            on_start=lambda url: logger.info(f'OpenAI-compatible API URL:\n\n{url}\n')
+        )
     else:
-        if ssl_keyfile and ssl_certfile:
-            logger.info(f'OpenAI-compatible API URL:\n\nhttps://{server_addr}:{port}\n')
+        url_proto = 'https://' if (ssl_certfile and ssl_keyfile) else 'http://'
+        urls = [f'{url_proto}{addr}:{port}' for addr in server_addrs]
+        if len(urls) > 1:
+            logger.info('OpenAI-compatible API URLs:\n\n' + '\n'.join(urls) + '\n')
         else:
-            logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n')
+            logger.info('OpenAI-compatible API URL:\n\n' + '\n'.join(urls) + '\n')
 
+    # Log API keys
     if shared.args.api_key:
         if not shared.args.admin_key:
             shared.args.admin_key = shared.args.api_key
@@ -379,8 +394,9 @@ def run_server():
     if shared.args.admin_key and shared.args.admin_key != shared.args.api_key:
         logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
 
+    # Start server
     logging.getLogger("uvicorn.error").propagate = False
-    uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
+    uvicorn.run(app, host=server_addrs, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
 
 
 def setup():
diff --git a/modules/shared.py b/modules/shared.py
index 7829e462..a0070b1f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -193,6 +193,8 @@ group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudf
 group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
 group.add_argument('--api-key', type=str, default='', help='API authentication key.')
 group.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.')
+group.add_argument('--api-enable-ipv6', action='store_true', help='Enable IPv6 for the API')
+group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API')
 group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.')
 
 # Multimodal