Mirror of https://github.com/oobabooga/text-generation-webui.git

Commit 4aabff3728: Remove old API, launch OpenAI API with --api
Parent commit: 6a7cd01ebf

@@ -75,7 +75,7 @@
 " with open('temp_requirements.txt', 'w') as file:\n",
 " file.write('\\n'.join(textgen_requirements))\n",
 "\n",
-" !pip install -r extensions/api/requirements.txt --upgrade\n",
+" !pip install -r extensions/openai/requirements.txt --upgrade\n",
 " !pip install -r temp_requirements.txt --upgrade\n",
 "\n",
 " print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",

@@ -10,7 +10,7 @@ pip install -r extensions/openai/requirements.txt
 
 ### Starting the API
 
-Add `--extensions openai` to your command-line flags.
+Add `--api` to your command-line flags.
 
 * To create a public Cloudflare URL, add the `--public-api` flag.
 * To listen on your local network, add the `--listen` flag.

@@ -18,31 +18,6 @@ Add `--extensions openai` to your command-line flags.
 * To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
 * To use an API key for authentication, add `--api-key yourkey`.
 
-#### Environment variables
-
-The following environment variables can be used (they take precendence over everything else):
-
-| Variable Name | Description | Example Value |
-|------------------------|------------------------------------|----------------------------|
-| `OPENEDAI_PORT` | Port number | 5000 |
-| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
-| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
-| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
-| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
-| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
-| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
-
-#### Persistent settings with `settings.yaml`
-
-You can also set the following variables in your `settings.yaml` file:
-
-```
-openai-embedding_device: cuda
-openai-embedding_model: all-mpnet-base-v2
-openai-sd_webui_url: http://127.0.0.1:7861
-openai-debug: 1
-```
-
 ### Examples
 
 For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
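
As a quick illustration of the new workflow documented above (launch with `--api`, then query the OpenAI-compatible server), here is a minimal client sketch. The default port 5000 matches the docs above, but the `/v1/completions` route and the OpenAI-style `choices[0].text` response shape are assumptions that do not appear in this diff.

```python
# Minimal sketch: query the OpenAI-compatible API after starting the server
# with something like:  python server.py --api
# The /v1/completions route and the "choices"[0]["text"] response field follow
# the OpenAI convention and are assumed here, not taken from this diff.
import requests

url = "http://127.0.0.1:5000/v1/completions"
payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "temperature": 0.7,
}

r = requests.post(url, json=payload)
r.raise_for_status()
print(r.json()["choices"][0]["text"])
```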

@@ -220,6 +195,31 @@ for event in client.events():
     print()
 ```
 
+### Environment variables
+
+The following environment variables can be used (they take precendence over everything else):
+
+| Variable Name | Description | Example Value |
+|------------------------|------------------------------------|----------------------------|
+| `OPENEDAI_PORT` | Port number | 5000 |
+| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
+| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
+| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
+| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
+| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
+| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
+
+#### Persistent settings with `settings.yaml`
+
+You can also set the following variables in your `settings.yaml` file:
+
+```
+openai-embedding_device: cuda
+openai-embedding_model: all-mpnet-base-v2
+openai-sd_webui_url: http://127.0.0.1:7861
+openai-debug: 1
+```
+
 ### Third-party application setup
 
 You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:

@@ -1,232 +0,0 @@
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
    encode,
    generate_reply,
    stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger


def get_model_info():
    return {
        'model_name': shared.model_name,
        'lora_names': shared.lora_names,
        # dump
        'shared.settings': shared.settings,
        'shared.args': vars(shared.args),
    }


class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/api/v1/model':
            self.send_response(200)
            self.end_headers()
            response = json.dumps({
                'result': shared.model_name
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/api/v1/generate':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            prompt = body['prompt']
            generate_params = build_parameters(body)
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False

            generator = generate_reply(
                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

            answer = ''
            for a in generator:
                answer = a

            response = json.dumps({
                'results': [{
                    'text': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/chat':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            user_input = body['user_input']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False

            generator = generate_chat_reply(
                user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

            answer = generate_params['history']
            for a in generator:
                answer = a

            response = json.dumps({
                'results': [{
                    'history': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/stop-stream':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            stop_everything_event()

            response = json.dumps({
                'results': 'success'
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/model':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # by default return the same as the GET interface
            result = shared.model_name

            # Actions: info, load, list, unload
            action = body.get('action', '')

            if action == 'load':
                model_name = body['model_name']
                args = body.get('args', {})
                print('args', args)
                for k in args:
                    setattr(shared.args, k, args[k])

                shared.model_name = model_name
                unload_model()

                model_settings = get_model_metadata(shared.model_name)
                shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
                update_model_parameters(model_settings, initial=True)

                if shared.settings['mode'] != 'instruct':
                    shared.settings['instruction_template'] = None

                try:
                    shared.model, shared.tokenizer = load_model(shared.model_name)
                    if shared.args.lora:
                        add_lora_to_model(shared.args.lora)  # list

                except Exception as e:
                    response = json.dumps({'error': {'message': repr(e)}})

                    self.wfile.write(response.encode('utf-8'))
                    raise e

                shared.args.model = shared.model_name

                result = get_model_info()

            elif action == 'unload':
                unload_model()
                shared.model_name = None
                shared.args.model = None
                result = get_model_info()

            elif action == 'list':
                result = get_available_models()

            elif action == 'info':
                result = get_model_info()

            response = json.dumps({
                'result': result,
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            tokens = encode(body['prompt'])[0]
            response = json.dumps({
                'results': [{
                    'tokens': len(tokens)
                }]
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)

    def do_OPTIONS(self):
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
        super().end_headers()


def _run_server(port: int, share: bool = False, tunnel_id=str):
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
    if ssl_verify:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    def on_start(public_url: str):
        logger.info(f'Blocking API URL: \n\n{public_url}/api\n')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception:
            pass
    else:
        if ssl_verify:
            logger.info(f'Blocking API URL: \n\nhttps://{address}:{port}/api\n')
        else:
            logger.info(f'Blocking API URL: \n\nhttp://{address}:{port}/api\n')

    server.serve_forever()


def start_server(port: int, share: bool = False, tunnel_id=str):
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
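
For reference, the file removed above implemented the old blocking HTTP API. A client interaction with it looked roughly like the sketch below; the endpoint paths and JSON shapes are taken from the deleted Handler, while the port (5000, the old default blocking port) is an assumption.

```python
# Sketch of a client for the removed blocking API. Endpoints and JSON shapes
# come from the deleted Handler above; the default port of 5000 is assumed.
import requests

base = "http://127.0.0.1:5000"

# Text generation: POST /api/v1/generate with a prompt and optional sampler settings.
resp = requests.post(
    f"{base}/api/v1/generate",
    json={"prompt": "Hello, my name is", "max_new_tokens": 50},
)
print(resp.json()["results"][0]["text"])

# Token counting: POST /api/v1/token-count with the text to measure.
resp = requests.post(
    f"{base}/api/v1/token-count",
    json={"prompt": "Hello, my name is"},
)
print(resp.json()["results"][0]["tokens"])
```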

@@ -1,2 +0,0 @@
flask_cloudflared==0.0.14
websockets==11.0.2

@@ -1,15 +0,0 @@
import time

import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
from modules.logging_colors import logger


def setup():
    logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API")
    blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
    if shared.args.public_api:
        time.sleep(5)

    streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)

@@ -1,142 +0,0 @@
import asyncio
import json
import ssl
from threading import Thread

from websockets.server import serve

from extensions.api.util import (
    build_parameters,
    try_start_cloudflared,
    with_api_lock
)
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from modules.logging_colors import logger

PATH = '/api/v1/stream'


@with_api_lock
async def _handle_stream_message(websocket, message):
    message = json.loads(message)

    prompt = message['prompt']
    generate_params = build_parameters(message)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # As we stream, only send the new bytes.
    skip_index = 0
    message_num = 0

    for a in generator:
        to_send = a[skip_index:]
        if to_send is None or chr(0xfffd) in to_send:  # partial unicode character, don't send it yet.
            continue

        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'text': to_send
        }))

        await asyncio.sleep(0)
        skip_index += len(to_send)
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    body = json.loads(message)

    user_input = body['user_input']
    generate_params = build_parameters(body, chat=True)
    generate_params['stream'] = True
    regenerate = body.get('regenerate', False)
    _continue = body.get('_continue', False)

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

    message_num = 0
    for a in generator:
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'history': a
        }))

        await asyncio.sleep(0)
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


async def _handle_connection(websocket, path):

    if path == '/api/v1/stream':
        async for message in websocket:
            await _handle_stream_message(websocket, message)

    elif path == '/api/v1/chat-stream':
        async for message in websocket:
            await _handle_chat_stream_message(websocket, message)

    else:
        print(f'Streaming api: unknown path: {path}')
        return


async def _run(host: str, port: int):
    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
    if ssl_verify:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ssl_certfile, ssl_keyfile)
    else:
        context = None

    async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
        await asyncio.Future()  # Run the server forever


def _run_server(port: int, share: bool = False, tunnel_id=str):
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    ssl_certfile = shared.args.ssl_certfile
    ssl_keyfile = shared.args.ssl_keyfile
    ssl_verify = True if (ssl_keyfile and ssl_certfile) else False

    def on_start(public_url: str):
        public_url = public_url.replace('https://', 'wss://')
        logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as e:
            print(e)
    else:
        if ssl_verify:
            logger.info(f'Streaming API URL: \n\nwss://{address}:{port}{PATH}\n')
        else:
            logger.info(f'Streaming API URL: \n\nws://{address}:{port}{PATH}\n')

    asyncio.run(_run(host=address, port=port))


def start_server(port: int, share: bool = False, tunnel_id=str):
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
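
The streaming API removed above was consumed over a websocket. A client sketch based on the deleted handlers follows; the `/api/v1/stream` path and event names come from the code above, while the default streaming port of 5005 is an assumption.

```python
# Sketch of a client for the removed streaming API. The message format and the
# "text_stream"/"stream_end" events come from _handle_stream_message above;
# port 5005 (the old default streaming port) is assumed.
import asyncio
import json

import websockets  # pip install websockets


async def stream(prompt: str) -> None:
    uri = "ws://127.0.0.1:5005/api/v1/stream"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"prompt": prompt, "max_new_tokens": 100}))
        while True:
            event = json.loads(await ws.recv())
            if event["event"] == "text_stream":
                print(event["text"], end="", flush=True)
            elif event["event"] == "stream_end":
                break


asyncio.run(stream("Hello, my name is"))
```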

@@ -1,156 +0,0 @@
import asyncio
import functools
import threading
import time
import traceback
from threading import Thread
from typing import Callable, Optional

from modules import shared
from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized

# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future,
# thus handling multiple requests in parallel.
api_tls = threading.local()


def build_parameters(body, chat=False):

    generate_params = {
        'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
        'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
        'max_tokens_second': int(body.get('max_tokens_second', 0)),
        'do_sample': bool(body.get('do_sample', True)),
        'temperature': float(body.get('temperature', 0.5)),
        'temperature_last': bool(body.get('temperature_last', False)),
        'top_p': float(body.get('top_p', 1)),
        'min_p': float(body.get('min_p', 0)),
        'typical_p': float(body.get('typical_p', body.get('typical', 1))),
        'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
        'eta_cutoff': float(body.get('eta_cutoff', 0)),
        'tfs': float(body.get('tfs', 1)),
        'top_a': float(body.get('top_a', 0)),
        'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
        'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))),
        'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))),
        'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
        'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
        'top_k': int(body.get('top_k', 0)),
        'min_length': int(body.get('min_length', 0)),
        'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
        'num_beams': int(body.get('num_beams', 1)),
        'penalty_alpha': float(body.get('penalty_alpha', 0)),
        'length_penalty': float(body.get('length_penalty', 1)),
        'early_stopping': bool(body.get('early_stopping', False)),
        'mirostat_mode': int(body.get('mirostat_mode', 0)),
        'mirostat_tau': float(body.get('mirostat_tau', 5)),
        'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
        'grammar_string': str(body.get('grammar_string', '')),
        'guidance_scale': float(body.get('guidance_scale', 1)),
        'negative_prompt': str(body.get('negative_prompt', '')),
        'seed': int(body.get('seed', -1)),
        'add_bos_token': bool(body.get('add_bos_token', True)),
        'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
        'custom_token_bans': str(body.get('custom_token_bans', '')),
        'ban_eos_token': bool(body.get('ban_eos_token', False)),
        'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
        'custom_stopping_strings': '',  # leave this blank
        'stopping_strings': body.get('stopping_strings', []),
    }

    preset_name = body.get('preset', 'None')
    if preset_name not in ['None', None, '']:
        preset = load_preset_memoized(preset_name)
        generate_params.update(preset)

    if chat:
        character = body.get('character')
        instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
        if str(instruction_template) == "None":
            instruction_template = "Vicuna-v1.1"
        if str(character) == "None":
            character = "Assistant"

        name1, name2, _, greeting, context, _, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False)
        name1_instruct, name2_instruct, _, _, context_instruct, turn_template, _ = load_character_memoized(instruction_template, '', '', instruct=True)
        generate_params.update({
            'mode': str(body.get('mode', 'chat')),
            'name1': str(body.get('name1', name1)),
            'name2': str(body.get('name2', name2)),
            'context': str(body.get('context', context)),
            'greeting': str(body.get('greeting', greeting)),
            'name1_instruct': str(body.get('name1_instruct', name1_instruct)),
            'name2_instruct': str(body.get('name2_instruct', name2_instruct)),
            'context_instruct': str(body.get('context_instruct', context_instruct)),
            'turn_template': str(body.get('turn_template', turn_template)),
            'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))),
            'history': body.get('history', {'internal': [], 'visible': []})
        })

    return generate_params


def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    Thread(target=_start_cloudflared, args=[
        port, tunnel_id, max_attempts, on_start], daemon=True).start()


def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    try:
        from flask_cloudflared import _run_cloudflared
    except ImportError:
        print('You should install flask_cloudflared manually')
        raise Exception(
            'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')

    for _ in range(max_attempts):
        try:
            if tunnel_id is not None:
                public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
            else:
                public_url = _run_cloudflared(port, port + 1)

            if on_start:
                on_start(public_url)

            return
        except Exception:
            traceback.print_exc()
            time.sleep(3)

    raise Exception('Could not start cloudflared.')


def _get_api_lock(tls) -> asyncio.Lock:
    """
    The streaming and blocking API implementations each run on their own
    thread, and multiplex requests using asyncio. If multiple outstanding
    requests are received at once, we will try to acquire the shared lock
    shared.generation_lock multiple times in succession in the same thread,
    which will cause a deadlock.

    To avoid this, we use this wrapper function to block on an asyncio
    lock, and then try and grab the shared lock only while holding
    the asyncio lock.
    """
    if not hasattr(tls, "asyncio_lock"):
        tls.asyncio_lock = asyncio.Lock()

    return tls.asyncio_lock


def with_api_lock(func):
    """
    This decorator should be added to all streaming API methods which
    require access to the shared.generation_lock. It ensures that the
    tls.asyncio_lock is acquired before the method is called, and
    released afterwards.
    """
    @functools.wraps(func)
    async def api_wrapper(*args, **kwargs):
        async with _get_api_lock(api_tls):
            return await func(*args, **kwargs)
    return api_wrapper
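
The docstrings in the file above describe a per-thread asyncio lock that kept the blocking and streaming servers from deadlocking on `shared.generation_lock`. Below is a standalone sketch of that pattern, independent of the webui codebase; every name in it is illustrative.

```python
# Standalone sketch of the locking pattern from the removed util: each thread
# keeps its own asyncio.Lock in a threading.local(), and the decorator serializes
# that thread's coroutines before they would touch a shared resource.
import asyncio
import functools
import threading

_tls = threading.local()


def _get_lock() -> asyncio.Lock:
    if not hasattr(_tls, "lock"):
        _tls.lock = asyncio.Lock()
    return _tls.lock


def with_lock(func):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        async with _get_lock():
            return await func(*args, **kwargs)
    return wrapper


@with_lock
async def generate(n: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for a generation call
    return n


async def main():
    # The calls are scheduled concurrently but enter the critical section one at a time.
    print(await asyncio.gather(*(generate(i) for i in range(3))))


asyncio.run(main())
```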

@@ -295,14 +295,14 @@ def run_server():
 
     if shared.args.public_api:
         def on_start(public_url: str):
-            logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
+            logger.info(f'OpenAI compatible API URL:\n\n{public_url}\n')
 
         _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
     else:
         if ssl_keyfile and ssl_certfile:
-            logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
+            logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}\n')
         else:
-            logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
+            logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}\n')
 
     if shared.args.api_key:
         logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
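
This hunk only changes which URL gets logged for the OpenAI-compatible server. For context, a third-party OpenAI client would be pointed at that URL; here is a sketch assuming the pre-1.0 `openai` Python package and the default local address, neither of which appears in this diff.

```python
# Hypothetical third-party setup: point the legacy (pre-1.0) openai package at
# the locally hosted server. Host, port and the /v1 suffix are assumptions.
import openai

openai.api_key = "yourkey"                    # matches --api-key, or any placeholder
openai.api_base = "http://127.0.0.1:5000/v1"  # base URL of the local OpenAI-compatible API

reply = openai.ChatCompletion.create(
    model="anything",  # the local backend answers with whichever model is loaded (assumption)
    messages=[{"role": "user", "content": "Hello!"}],
)
print(reply["choices"][0]["message"]["content"])
```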

@@ -258,9 +258,8 @@ if args.multimodal_pipeline is not None:
     add_extension('multimodal')
 
 # Activate the API extension
-if args.api:
-    # add_extension('openai', last=True)
-    add_extension('api', last=True)
+if args.api or args.public_api:
+    add_extension('openai', last=True)
 
 # Load model-specific settings
 with Path(f'{args.model_dir}/config.yaml') as p:

@@ -9,6 +9,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
+warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
 
 with RequestBlocker():
     import gradio as gr