From 4aabff3728e00ee2a4afa5e6ee8079bab0523dbd Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 10 Nov 2023 06:39:08 -0800 Subject: [PATCH] Remove old API, launch OpenAI API with --api --- Colab-TextGen-GPU.ipynb | 2 +- docs/12 - OpenAI API.md | 52 +++---- extensions/api/blocking_api.py | 232 -------------------------------- extensions/api/requirements.txt | 2 - extensions/api/script.py | 15 --- extensions/api/streaming_api.py | 142 ------------------- extensions/api/util.py | 156 --------------------- extensions/openai/script.py | 6 +- modules/shared.py | 5 +- server.py | 1 + 10 files changed, 33 insertions(+), 580 deletions(-) delete mode 100644 extensions/api/blocking_api.py delete mode 100644 extensions/api/requirements.txt delete mode 100644 extensions/api/script.py delete mode 100644 extensions/api/streaming_api.py delete mode 100644 extensions/api/util.py diff --git a/Colab-TextGen-GPU.ipynb b/Colab-TextGen-GPU.ipynb index 66953bda..73580a91 100644 --- a/Colab-TextGen-GPU.ipynb +++ b/Colab-TextGen-GPU.ipynb @@ -75,7 +75,7 @@ " with open('temp_requirements.txt', 'w') as file:\n", " file.write('\\n'.join(textgen_requirements))\n", "\n", - " !pip install -r extensions/api/requirements.txt --upgrade\n", + " !pip install -r extensions/openai/requirements.txt --upgrade\n", " !pip install -r temp_requirements.txt --upgrade\n", "\n", " print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n", diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md index c0261785..e779f492 100644 --- a/docs/12 - OpenAI API.md +++ b/docs/12 - OpenAI API.md @@ -10,7 +10,7 @@ pip install -r extensions/openai/requirements.txt ### Starting the API -Add `--extensions openai` to your command-line flags. +Add `--api` to your command-line flags. * To create a public Cloudflare URL, add the `--public-api` flag. * To listen on your local network, add the `--listen` flag. @@ -18,31 +18,6 @@ Add `--extensions openai` to your command-line flags. * To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`. * To use an API key for authentication, add `--api-key yourkey`. -#### Environment variables - -The following environment variables can be used (they take precendence over everything else): - -| Variable Name | Description | Example Value | -|------------------------|------------------------------------|----------------------------| -| `OPENEDAI_PORT` | Port number | 5000 | -| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem | -| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem | -| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 | -| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 | -| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 | -| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda | - -#### Persistent settings with `settings.yaml` - -You can also set the following variables in your `settings.yaml` file: - -``` -openai-embedding_device: cuda -openai-embedding_model: all-mpnet-base-v2 -openai-sd_webui_url: http://127.0.0.1:7861 -openai-debug: 1 -``` - ### Examples For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file. @@ -220,6 +195,31 @@ for event in client.events(): print() ``` +### Environment variables + +The following environment variables can be used (they take precendence over everything else): + +| Variable Name | Description | Example Value | +|------------------------|------------------------------------|----------------------------| +| `OPENEDAI_PORT` | Port number | 5000 | +| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem | +| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem | +| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 | +| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 | +| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 | +| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda | + +#### Persistent settings with `settings.yaml` + +You can also set the following variables in your `settings.yaml` file: + +``` +openai-embedding_device: cuda +openai-embedding_model: all-mpnet-base-v2 +openai-sd_webui_url: http://127.0.0.1:7861 +openai-debug: 1 +``` + ### Third-party application setup You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables: diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py deleted file mode 100644 index ecc327aa..00000000 --- a/extensions/api/blocking_api.py +++ /dev/null @@ -1,232 +0,0 @@ -import json -import ssl -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from threading import Thread - -from extensions.api.util import build_parameters, try_start_cloudflared -from modules import shared -from modules.chat import generate_chat_reply -from modules.LoRA import add_lora_to_model -from modules.models import load_model, unload_model -from modules.models_settings import get_model_metadata, update_model_parameters -from modules.text_generation import ( - encode, - generate_reply, - stop_everything_event -) -from modules.utils import get_available_models -from modules.logging_colors import logger - - -def get_model_info(): - return { - 'model_name': shared.model_name, - 'lora_names': shared.lora_names, - # dump - 'shared.settings': shared.settings, - 'shared.args': vars(shared.args), - } - - -class Handler(BaseHTTPRequestHandler): - def do_GET(self): - if self.path == '/api/v1/model': - self.send_response(200) - self.end_headers() - response = json.dumps({ - 'result': shared.model_name - }) - - self.wfile.write(response.encode('utf-8')) - else: - self.send_error(404) - - def do_POST(self): - content_length = int(self.headers['Content-Length']) - body = json.loads(self.rfile.read(content_length).decode('utf-8')) - - if self.path == '/api/v1/generate': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - prompt = body['prompt'] - generate_params = build_parameters(body) - stopping_strings = generate_params.pop('stopping_strings') - generate_params['stream'] = False - - generator = generate_reply( - prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) - - answer = '' - for a in generator: - answer = a - - response = json.dumps({ - 'results': [{ - 'text': answer - }] - }) - - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/chat': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - user_input = body['user_input'] - regenerate = body.get('regenerate', False) - _continue = body.get('_continue', False) - - generate_params = build_parameters(body, chat=True) - generate_params['stream'] = False - - generator = generate_chat_reply( - user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) - - answer = generate_params['history'] - for a in generator: - answer = a - - response = json.dumps({ - 'results': [{ - 'history': answer - }] - }) - - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/stop-stream': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - stop_everything_event() - - response = json.dumps({ - 'results': 'success' - }) - - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/model': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - # by default return the same as the GET interface - result = shared.model_name - - # Actions: info, load, list, unload - action = body.get('action', '') - - if action == 'load': - model_name = body['model_name'] - args = body.get('args', {}) - print('args', args) - for k in args: - setattr(shared.args, k, args[k]) - - shared.model_name = model_name - unload_model() - - model_settings = get_model_metadata(shared.model_name) - shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) - update_model_parameters(model_settings, initial=True) - - if shared.settings['mode'] != 'instruct': - shared.settings['instruction_template'] = None - - try: - shared.model, shared.tokenizer = load_model(shared.model_name) - if shared.args.lora: - add_lora_to_model(shared.args.lora) # list - - except Exception as e: - response = json.dumps({'error': {'message': repr(e)}}) - - self.wfile.write(response.encode('utf-8')) - raise e - - shared.args.model = shared.model_name - - result = get_model_info() - - elif action == 'unload': - unload_model() - shared.model_name = None - shared.args.model = None - result = get_model_info() - - elif action == 'list': - result = get_available_models() - - elif action == 'info': - result = get_model_info() - - response = json.dumps({ - 'result': result, - }) - - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/token-count': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - tokens = encode(body['prompt'])[0] - response = json.dumps({ - 'results': [{ - 'tokens': len(tokens) - }] - }) - - self.wfile.write(response.encode('utf-8')) - else: - self.send_error(404) - - def do_OPTIONS(self): - self.send_response(200) - self.end_headers() - - def end_headers(self): - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', '*') - self.send_header('Access-Control-Allow-Headers', '*') - self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate') - super().end_headers() - - -def _run_server(port: int, share: bool = False, tunnel_id=str): - address = '0.0.0.0' if shared.args.listen else '127.0.0.1' - server = ThreadingHTTPServer((address, port), Handler) - - ssl_certfile = shared.args.ssl_certfile - ssl_keyfile = shared.args.ssl_keyfile - ssl_verify = True if (ssl_keyfile and ssl_certfile) else False - if ssl_verify: - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.load_cert_chain(ssl_certfile, ssl_keyfile) - server.socket = context.wrap_socket(server.socket, server_side=True) - - def on_start(public_url: str): - logger.info(f'Blocking API URL: \n\n{public_url}/api\n') - - if share: - try: - try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) - except Exception: - pass - else: - if ssl_verify: - logger.info(f'Blocking API URL: \n\nhttps://{address}:{port}/api\n') - else: - logger.info(f'Blocking API URL: \n\nhttp://{address}:{port}/api\n') - - server.serve_forever() - - -def start_server(port: int, share: bool = False, tunnel_id=str): - Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt deleted file mode 100644 index e4f26c3a..00000000 --- a/extensions/api/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -flask_cloudflared==0.0.14 -websockets==11.0.2 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py deleted file mode 100644 index 2dc57a7f..00000000 --- a/extensions/api/script.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - -import extensions.api.blocking_api as blocking_api -import extensions.api.streaming_api as streaming_api -from modules import shared -from modules.logging_colors import logger - - -def setup(): - logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API") - blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) - if shared.args.public_api: - time.sleep(5) - - streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py deleted file mode 100644 index 2968ed8d..00000000 --- a/extensions/api/streaming_api.py +++ /dev/null @@ -1,142 +0,0 @@ -import asyncio -import json -import ssl -from threading import Thread - -from websockets.server import serve - -from extensions.api.util import ( - build_parameters, - try_start_cloudflared, - with_api_lock -) -from modules import shared -from modules.chat import generate_chat_reply -from modules.text_generation import generate_reply -from modules.logging_colors import logger - -PATH = '/api/v1/stream' - - -@with_api_lock -async def _handle_stream_message(websocket, message): - message = json.loads(message) - - prompt = message['prompt'] - generate_params = build_parameters(message) - stopping_strings = generate_params.pop('stopping_strings') - generate_params['stream'] = True - - generator = generate_reply( - prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) - - # As we stream, only send the new bytes. - skip_index = 0 - message_num = 0 - - for a in generator: - to_send = a[skip_index:] - if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet. - continue - - await websocket.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': to_send - })) - - await asyncio.sleep(0) - skip_index += len(to_send) - message_num += 1 - - await websocket.send(json.dumps({ - 'event': 'stream_end', - 'message_num': message_num - })) - - -@with_api_lock -async def _handle_chat_stream_message(websocket, message): - body = json.loads(message) - - user_input = body['user_input'] - generate_params = build_parameters(body, chat=True) - generate_params['stream'] = True - regenerate = body.get('regenerate', False) - _continue = body.get('_continue', False) - - generator = generate_chat_reply( - user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False) - - message_num = 0 - for a in generator: - await websocket.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'history': a - })) - - await asyncio.sleep(0) - message_num += 1 - - await websocket.send(json.dumps({ - 'event': 'stream_end', - 'message_num': message_num - })) - - -async def _handle_connection(websocket, path): - - if path == '/api/v1/stream': - async for message in websocket: - await _handle_stream_message(websocket, message) - - elif path == '/api/v1/chat-stream': - async for message in websocket: - await _handle_chat_stream_message(websocket, message) - - else: - print(f'Streaming api: unknown path: {path}') - return - - -async def _run(host: str, port: int): - ssl_certfile = shared.args.ssl_certfile - ssl_keyfile = shared.args.ssl_keyfile - ssl_verify = True if (ssl_keyfile and ssl_certfile) else False - if ssl_verify: - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.load_cert_chain(ssl_certfile, ssl_keyfile) - else: - context = None - - async with serve(_handle_connection, host, port, ping_interval=None, ssl=context): - await asyncio.Future() # Run the server forever - - -def _run_server(port: int, share: bool = False, tunnel_id=str): - address = '0.0.0.0' if shared.args.listen else '127.0.0.1' - ssl_certfile = shared.args.ssl_certfile - ssl_keyfile = shared.args.ssl_keyfile - ssl_verify = True if (ssl_keyfile and ssl_certfile) else False - - def on_start(public_url: str): - public_url = public_url.replace('https://', 'wss://') - logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n') - - if share: - try: - try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) - except Exception as e: - print(e) - else: - if ssl_verify: - logger.info(f'Streaming API URL: \n\nwss://{address}:{port}{PATH}\n') - else: - logger.info(f'Streaming API URL: \n\nws://{address}:{port}{PATH}\n') - - asyncio.run(_run(host=address, port=port)) - - -def start_server(port: int, share: bool = False, tunnel_id=str): - Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/util.py b/extensions/api/util.py deleted file mode 100644 index 2c7e73fc..00000000 --- a/extensions/api/util.py +++ /dev/null @@ -1,156 +0,0 @@ -import asyncio -import functools -import threading -import time -import traceback -from threading import Thread -from typing import Callable, Optional - -from modules import shared -from modules.chat import load_character_memoized -from modules.presets import load_preset_memoized - -# We use a thread local to store the asyncio lock, so that each thread -# has its own lock. This isn't strictly necessary, but it makes it -# such that if we can support multiple worker threads in the future, -# thus handling multiple requests in parallel. -api_tls = threading.local() - - -def build_parameters(body, chat=False): - - generate_params = { - 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), - 'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)), - 'max_tokens_second': int(body.get('max_tokens_second', 0)), - 'do_sample': bool(body.get('do_sample', True)), - 'temperature': float(body.get('temperature', 0.5)), - 'temperature_last': bool(body.get('temperature_last', False)), - 'top_p': float(body.get('top_p', 1)), - 'min_p': float(body.get('min_p', 0)), - 'typical_p': float(body.get('typical_p', body.get('typical', 1))), - 'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)), - 'eta_cutoff': float(body.get('eta_cutoff', 0)), - 'tfs': float(body.get('tfs', 1)), - 'top_a': float(body.get('top_a', 0)), - 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), - 'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))), - 'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))), - 'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)), - 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), - 'top_k': int(body.get('top_k', 0)), - 'min_length': int(body.get('min_length', 0)), - 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), - 'num_beams': int(body.get('num_beams', 1)), - 'penalty_alpha': float(body.get('penalty_alpha', 0)), - 'length_penalty': float(body.get('length_penalty', 1)), - 'early_stopping': bool(body.get('early_stopping', False)), - 'mirostat_mode': int(body.get('mirostat_mode', 0)), - 'mirostat_tau': float(body.get('mirostat_tau', 5)), - 'mirostat_eta': float(body.get('mirostat_eta', 0.1)), - 'grammar_string': str(body.get('grammar_string', '')), - 'guidance_scale': float(body.get('guidance_scale', 1)), - 'negative_prompt': str(body.get('negative_prompt', '')), - 'seed': int(body.get('seed', -1)), - 'add_bos_token': bool(body.get('add_bos_token', True)), - 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), - 'custom_token_bans': str(body.get('custom_token_bans', '')), - 'ban_eos_token': bool(body.get('ban_eos_token', False)), - 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), - 'custom_stopping_strings': '', # leave this blank - 'stopping_strings': body.get('stopping_strings', []), - } - - preset_name = body.get('preset', 'None') - if preset_name not in ['None', None, '']: - preset = load_preset_memoized(preset_name) - generate_params.update(preset) - - if chat: - character = body.get('character') - instruction_template = body.get('instruction_template', shared.settings['instruction_template']) - if str(instruction_template) == "None": - instruction_template = "Vicuna-v1.1" - if str(character) == "None": - character = "Assistant" - - name1, name2, _, greeting, context, _, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False) - name1_instruct, name2_instruct, _, _, context_instruct, turn_template, _ = load_character_memoized(instruction_template, '', '', instruct=True) - generate_params.update({ - 'mode': str(body.get('mode', 'chat')), - 'name1': str(body.get('name1', name1)), - 'name2': str(body.get('name2', name2)), - 'context': str(body.get('context', context)), - 'greeting': str(body.get('greeting', greeting)), - 'name1_instruct': str(body.get('name1_instruct', name1_instruct)), - 'name2_instruct': str(body.get('name2_instruct', name2_instruct)), - 'context_instruct': str(body.get('context_instruct', context_instruct)), - 'turn_template': str(body.get('turn_template', turn_template)), - 'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))), - 'history': body.get('history', {'internal': [], 'visible': []}) - }) - - return generate_params - - -def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): - Thread(target=_start_cloudflared, args=[ - port, tunnel_id, max_attempts, on_start], daemon=True).start() - - -def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): - try: - from flask_cloudflared import _run_cloudflared - except ImportError: - print('You should install flask_cloudflared manually') - raise Exception( - 'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.') - - for _ in range(max_attempts): - try: - if tunnel_id is not None: - public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id) - else: - public_url = _run_cloudflared(port, port + 1) - - if on_start: - on_start(public_url) - - return - except Exception: - traceback.print_exc() - time.sleep(3) - - raise Exception('Could not start cloudflared.') - - -def _get_api_lock(tls) -> asyncio.Lock: - """ - The streaming and blocking API implementations each run on their own - thread, and multiplex requests using asyncio. If multiple outstanding - requests are received at once, we will try to acquire the shared lock - shared.generation_lock multiple times in succession in the same thread, - which will cause a deadlock. - - To avoid this, we use this wrapper function to block on an asyncio - lock, and then try and grab the shared lock only while holding - the asyncio lock. - """ - if not hasattr(tls, "asyncio_lock"): - tls.asyncio_lock = asyncio.Lock() - - return tls.asyncio_lock - - -def with_api_lock(func): - """ - This decorator should be added to all streaming API methods which - require access to the shared.generation_lock. It ensures that the - tls.asyncio_lock is acquired before the method is called, and - released afterwards. - """ - @functools.wraps(func) - async def api_wrapper(*args, **kwargs): - async with _get_api_lock(api_tls): - return await func(*args, **kwargs) - return api_wrapper diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 40574a72..8bf78ca2 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -295,14 +295,14 @@ def run_server(): if shared.args.public_api: def on_start(public_url: str): - logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n') + logger.info(f'OpenAI compatible API URL:\n\n{public_url}\n') _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start) else: if ssl_keyfile and ssl_certfile: - logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n') + logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}\n') else: - logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n') + logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}\n') if shared.args.api_key: logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n') diff --git a/modules/shared.py b/modules/shared.py index d7bf3f57..8672e45f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -258,9 +258,8 @@ if args.multimodal_pipeline is not None: add_extension('multimodal') # Activate the API extension -if args.api: - # add_extension('openai', last=True) - add_extension('api', last=True) +if args.api or args.public_api: + add_extension('openai', last=True) # Load model-specific settings with Path(f'{args.model_dir}/config.yaml') as p: diff --git a/server.py b/server.py index 1a87ef45..e9605e3b 100644 --- a/server.py +++ b/server.py @@ -9,6 +9,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['BITSANDBYTES_NOWELCOME'] = '1' warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') +warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict') with RequestBlocker(): import gradio as gr