Remove old API, launch OpenAI API with --api

This commit is contained in:
oobabooga 2023-11-10 06:39:08 -08:00
parent 6a7cd01ebf
commit 4aabff3728
10 changed files with 33 additions and 580 deletions

View File

@ -75,7 +75,7 @@
" with open('temp_requirements.txt', 'w') as file:\n", " with open('temp_requirements.txt', 'w') as file:\n",
" file.write('\\n'.join(textgen_requirements))\n", " file.write('\\n'.join(textgen_requirements))\n",
"\n", "\n",
" !pip install -r extensions/api/requirements.txt --upgrade\n", " !pip install -r extensions/openai/requirements.txt --upgrade\n",
" !pip install -r temp_requirements.txt --upgrade\n", " !pip install -r temp_requirements.txt --upgrade\n",
"\n", "\n",
" print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n", " print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",

View File

@ -10,7 +10,7 @@ pip install -r extensions/openai/requirements.txt
### Starting the API ### Starting the API
Add `--extensions openai` to your command-line flags. Add `--api` to your command-line flags.
* To create a public Cloudflare URL, add the `--public-api` flag. * To create a public Cloudflare URL, add the `--public-api` flag.
* To listen on your local network, add the `--listen` flag. * To listen on your local network, add the `--listen` flag.
@ -18,31 +18,6 @@ Add `--extensions openai` to your command-line flags.
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`. * To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
* To use an API key for authentication, add `--api-key yourkey`. * To use an API key for authentication, add `--api-key yourkey`.
#### Environment variables
The following environment variables can be used (they take precendence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set the following variables in your `settings.yaml` file:
```
openai-embedding_device: cuda
openai-embedding_model: all-mpnet-base-v2
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Examples ### Examples
For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file. For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
@ -220,6 +195,31 @@ for event in client.events():
print() print()
``` ```
### Environment variables
The following environment variables can be used (they take precendence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set the following variables in your `settings.yaml` file:
```
openai-embedding_device: cuda
openai-embedding_model: all-mpnet-base-v2
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Third-party application setup ### Third-party application setup
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables: You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:

View File

@ -1,232 +0,0 @@
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
encode,
generate_reply,
stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger
def get_model_info():
return {
'model_name': shared.model_name,
'lora_names': shared.lora_names,
# dump
'shared.settings': shared.settings,
'shared.args': vars(shared.args),
}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
self.send_response(200)
self.end_headers()
response = json.dumps({
'result': shared.model_name
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if self.path == '/api/v1/generate':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
response = json.dumps({
'results': [{
'text': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/chat':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
user_input = body['user_input']
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = False
generator = generate_chat_reply(
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = generate_params['history']
for a in generator:
answer = a
response = json.dumps({
'results': [{
'history': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/stop-stream':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
stop_everything_event()
response = json.dumps({
'results': 'success'
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/model':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
# by default return the same as the GET interface
result = shared.model_name
# Actions: info, load, list, unload
action = body.get('action', '')
if action == 'load':
model_name = body['model_name']
args = body.get('args', {})
print('args', args)
for k in args:
setattr(shared.args, k, args[k])
shared.model_name = model_name
unload_model()
model_settings = get_model_metadata(shared.model_name)
shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct':
shared.settings['instruction_template'] = None
try:
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora) # list
except Exception as e:
response = json.dumps({'error': {'message': repr(e)}})
self.wfile.write(response.encode('utf-8'))
raise e
shared.args.model = shared.model_name
result = get_model_info()
elif action == 'unload':
unload_model()
shared.model_name = None
shared.args.model = None
result = get_model_info()
elif action == 'list':
result = get_available_models()
elif action == 'info':
result = get_model_info()
response = json.dumps({
'result': result,
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
tokens = encode(body['prompt'])[0]
response = json.dumps({
'results': [{
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_OPTIONS(self):
self.send_response(200)
self.end_headers()
def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
self.send_header('Access-Control-Allow-Headers', '*')
self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
super().end_headers()
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
def on_start(public_url: str):
logger.info(f'Blocking API URL: \n\n{public_url}/api\n')
if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception:
pass
else:
if ssl_verify:
logger.info(f'Blocking API URL: \n\nhttps://{address}:{port}/api\n')
else:
logger.info(f'Blocking API URL: \n\nhttp://{address}:{port}/api\n')
server.serve_forever()
def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View File

@ -1,2 +0,0 @@
flask_cloudflared==0.0.14
websockets==11.0.2

View File

@ -1,15 +0,0 @@
import time
import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
from modules.logging_colors import logger
def setup():
logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API")
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
if shared.args.public_api:
time.sleep(5)
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)

View File

@ -1,142 +0,0 @@
import asyncio
import json
import ssl
from threading import Thread
from websockets.server import serve
from extensions.api.util import (
build_parameters,
try_start_cloudflared,
with_api_lock
)
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from modules.logging_colors import logger
PATH = '/api/v1/stream'
@with_api_lock
async def _handle_stream_message(websocket, message):
message = json.loads(message)
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = 0
message_num = 0
for a in generator:
to_send = a[skip_index:]
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
continue
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
}))
await asyncio.sleep(0)
skip_index += len(to_send)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
body = json.loads(message)
user_input = body['user_input']
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = True
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generator = generate_chat_reply(
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
message_num = 0
for a in generator:
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'history': a
}))
await asyncio.sleep(0)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
async def _handle_connection(websocket, path):
if path == '/api/v1/stream':
async for message in websocket:
await _handle_stream_message(websocket, message)
elif path == '/api/v1/chat-stream':
async for message in websocket:
await _handle_chat_stream_message(websocket, message)
else:
print(f'Streaming api: unknown path: {path}')
return
async def _run(host: str, port: int):
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
else:
context = None
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
await asyncio.Future() # Run the server forever
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://')
logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')
if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception as e:
print(e)
else:
if ssl_verify:
logger.info(f'Streaming API URL: \n\nwss://{address}:{port}{PATH}\n')
else:
logger.info(f'Streaming API URL: \n\nws://{address}:{port}{PATH}\n')
asyncio.run(_run(host=address, port=port))
def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View File

@ -1,156 +0,0 @@
import asyncio
import functools
import threading
import time
import traceback
from threading import Thread
from typing import Callable, Optional
from modules import shared
from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future,
# thus handling multiple requests in parallel.
api_tls = threading.local()
def build_parameters(body, chat=False):
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
'max_tokens_second': int(body.get('max_tokens_second', 0)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'temperature_last': bool(body.get('temperature_last', False)),
'top_p': float(body.get('top_p', 1)),
'min_p': float(body.get('min_p', 0)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
'eta_cutoff': float(body.get('eta_cutoff', 0)),
'tfs': float(body.get('tfs', 1)),
'top_a': float(body.get('top_a', 0)),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))),
'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))),
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'mirostat_mode': int(body.get('mirostat_mode', 0)),
'mirostat_tau': float(body.get('mirostat_tau', 5)),
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
'grammar_string': str(body.get('grammar_string', '')),
'guidance_scale': float(body.get('guidance_scale', 1)),
'negative_prompt': str(body.get('negative_prompt', '')),
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
'custom_token_bans': str(body.get('custom_token_bans', '')),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
'stopping_strings': body.get('stopping_strings', []),
}
preset_name = body.get('preset', 'None')
if preset_name not in ['None', None, '']:
preset = load_preset_memoized(preset_name)
generate_params.update(preset)
if chat:
character = body.get('character')
instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
if str(instruction_template) == "None":
instruction_template = "Vicuna-v1.1"
if str(character) == "None":
character = "Assistant"
name1, name2, _, greeting, context, _, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, _ = load_character_memoized(instruction_template, '', '', instruct=True)
generate_params.update({
'mode': str(body.get('mode', 'chat')),
'name1': str(body.get('name1', name1)),
'name2': str(body.get('name2', name2)),
'context': str(body.get('context', context)),
'greeting': str(body.get('greeting', greeting)),
'name1_instruct': str(body.get('name1_instruct', name1_instruct)),
'name2_instruct': str(body.get('name2_instruct', name2_instruct)),
'context_instruct': str(body.get('context_instruct', context_instruct)),
'turn_template': str(body.get('turn_template', turn_template)),
'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))),
'history': body.get('history', {'internal': [], 'visible': []})
})
return generate_params
def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
Thread(target=_start_cloudflared, args=[
port, tunnel_id, max_attempts, on_start], daemon=True).start()
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')
def _get_api_lock(tls) -> asyncio.Lock:
"""
The streaming and blocking API implementations each run on their own
thread, and multiplex requests using asyncio. If multiple outstanding
requests are received at once, we will try to acquire the shared lock
shared.generation_lock multiple times in succession in the same thread,
which will cause a deadlock.
To avoid this, we use this wrapper function to block on an asyncio
lock, and then try and grab the shared lock only while holding
the asyncio lock.
"""
if not hasattr(tls, "asyncio_lock"):
tls.asyncio_lock = asyncio.Lock()
return tls.asyncio_lock
def with_api_lock(func):
"""
This decorator should be added to all streaming API methods which
require access to the shared.generation_lock. It ensures that the
tls.asyncio_lock is acquired before the method is called, and
released afterwards.
"""
@functools.wraps(func)
async def api_wrapper(*args, **kwargs):
async with _get_api_lock(api_tls):
return await func(*args, **kwargs)
return api_wrapper

View File

@ -295,14 +295,14 @@ def run_server():
if shared.args.public_api: if shared.args.public_api:
def on_start(public_url: str): def on_start(public_url: str):
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n') logger.info(f'OpenAI compatible API URL:\n\n{public_url}\n')
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start) _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
else: else:
if ssl_keyfile and ssl_certfile: if ssl_keyfile and ssl_certfile:
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n') logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}\n')
else: else:
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n') logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}\n')
if shared.args.api_key: if shared.args.api_key:
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n') logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')

View File

@ -258,9 +258,8 @@ if args.multimodal_pipeline is not None:
add_extension('multimodal') add_extension('multimodal')
# Activate the API extension # Activate the API extension
if args.api: if args.api or args.public_api:
# add_extension('openai', last=True) add_extension('openai', last=True)
add_extension('api', last=True)
# Load model-specific settings # Load model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p: with Path(f'{args.model_dir}/config.yaml') as p:

View File

@ -9,6 +9,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1' os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
with RequestBlocker(): with RequestBlocker():
import gradio as gr import gradio as gr