mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-26 14:20:40 +01:00
commit
454fcf39a9
@ -22,7 +22,7 @@
|
||||
"source": [
|
||||
"# oobabooga/text-generation-webui\n",
|
||||
"\n",
|
||||
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate API links.\n",
|
||||
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate an API link.\n",
|
||||
"\n",
|
||||
"* Project page: https://github.com/oobabooga/text-generation-webui\n",
|
||||
"* Gradio server status: https://status.gradio.app/"
|
||||
@ -75,7 +75,7 @@
|
||||
" with open('temp_requirements.txt', 'w') as file:\n",
|
||||
" file.write('\\n'.join(textgen_requirements))\n",
|
||||
"\n",
|
||||
" !pip install -r extensions/api/requirements.txt --upgrade\n",
|
||||
" !pip install -r extensions/openai/requirements.txt --upgrade\n",
|
||||
" !pip install -r temp_requirements.txt --upgrade\n",
|
||||
"\n",
|
||||
" print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",
|
||||
|
@ -1,6 +1,11 @@
|
||||
## OpenAI compatible API
|
||||
|
||||
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
|
||||
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
|
||||
|
||||
* It is 100% offline and private.
|
||||
* It doesn't create any logs.
|
||||
* It doesn't connect to OpenAI.
|
||||
* It doesn't use the openai-python library.
|
||||
|
||||
If you did not use the one-click installers, you may need to install the requirements first:
|
||||
|
||||
@ -10,7 +15,7 @@ pip install -r extensions/openai/requirements.txt
|
||||
|
||||
### Starting the API
|
||||
|
||||
Add `--extensions openai` to your command-line flags.
|
||||
Add `--api` to your command-line flags.
|
||||
|
||||
* To create a public Cloudflare URL, add the `--public-api` flag.
|
||||
* To listen on your local network, add the `--listen` flag.
|
||||
@ -18,31 +23,6 @@ Add `--extensions openai` to your command-line flags.
|
||||
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
|
||||
* To use an API key for authentication, add `--api-key yourkey`.
|
||||
|
||||
#### Environment variables
|
||||
|
||||
The following environment variables can be used (they take precendence over everything else):
|
||||
|
||||
| Variable Name | Description | Example Value |
|
||||
|------------------------|------------------------------------|----------------------------|
|
||||
| `OPENEDAI_PORT` | Port number | 5000 |
|
||||
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
|
||||
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
||||
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
||||
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
||||
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
|
||||
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||
|
||||
#### Persistent settings with `settings.yaml`
|
||||
|
||||
You can also set the following variables in your `settings.yaml` file:
|
||||
|
||||
```
|
||||
openai-embedding_device: cuda
|
||||
openai-embedding_model: all-mpnet-base-v2
|
||||
openai-sd_webui_url: http://127.0.0.1:7861
|
||||
openai-debug: 1
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
|
||||
@ -220,6 +200,31 @@ for event in client.events():
|
||||
print()
|
||||
```
|
||||
|
||||
### Environment variables
|
||||
|
||||
The following environment variables can be used (they take precendence over everything else):
|
||||
|
||||
| Variable Name | Description | Example Value |
|
||||
|------------------------|------------------------------------|----------------------------|
|
||||
| `OPENEDAI_PORT` | Port number | 5000 |
|
||||
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
|
||||
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
||||
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
||||
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
||||
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
|
||||
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||
|
||||
#### Persistent settings with `settings.yaml`
|
||||
|
||||
You can also set the following variables in your `settings.yaml` file:
|
||||
|
||||
```
|
||||
openai-embedding_device: cuda
|
||||
openai-embedding_model: "sentence-transformers/all-mpnet-base-v2"
|
||||
openai-sd_webui_url: http://127.0.0.1:7861
|
||||
openai-debug: 1
|
||||
```
|
||||
|
||||
### Third-party application setup
|
||||
|
||||
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
|
||||
|
@ -1,232 +0,0 @@
|
||||
import json
|
||||
import ssl
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||
from modules.text_generation import (
|
||||
encode,
|
||||
generate_reply,
|
||||
stop_everything_event
|
||||
)
|
||||
from modules.utils import get_available_models
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def get_model_info():
|
||||
return {
|
||||
'model_name': shared.model_name,
|
||||
'lora_names': shared.lora_names,
|
||||
# dump
|
||||
'shared.settings': shared.settings,
|
||||
'shared.args': vars(shared.args),
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
response = json.dumps({
|
||||
'result': shared.model_name
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_POST(self):
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||
|
||||
if self.path == '/api/v1/generate':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
generate_params = build_parameters(body)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'text': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/chat':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
user_input = body['user_input']
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
answer = generate_params['history']
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'history': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/stop-stream':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
stop_everything_event()
|
||||
|
||||
response = json.dumps({
|
||||
'results': 'success'
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
# by default return the same as the GET interface
|
||||
result = shared.model_name
|
||||
|
||||
# Actions: info, load, list, unload
|
||||
action = body.get('action', '')
|
||||
|
||||
if action == 'load':
|
||||
model_name = body['model_name']
|
||||
args = body.get('args', {})
|
||||
print('args', args)
|
||||
for k in args:
|
||||
setattr(shared.args, k, args[k])
|
||||
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_metadata(shared.model_name)
|
||||
shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
if shared.settings['mode'] != 'instruct':
|
||||
shared.settings['instruction_template'] = None
|
||||
|
||||
try:
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
if shared.args.lora:
|
||||
add_lora_to_model(shared.args.lora) # list
|
||||
|
||||
except Exception as e:
|
||||
response = json.dumps({'error': {'message': repr(e)}})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
raise e
|
||||
|
||||
shared.args.model = shared.model_name
|
||||
|
||||
result = get_model_info()
|
||||
|
||||
elif action == 'unload':
|
||||
unload_model()
|
||||
shared.model_name = None
|
||||
shared.args.model = None
|
||||
result = get_model_info()
|
||||
|
||||
elif action == 'list':
|
||||
result = get_available_models()
|
||||
|
||||
elif action == 'info':
|
||||
result = get_model_info()
|
||||
|
||||
response = json.dumps({
|
||||
'result': result,
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
tokens = encode(body['prompt'])[0]
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', '*')
|
||||
self.send_header('Access-Control-Allow-Headers', '*')
|
||||
self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
|
||||
super().end_headers()
|
||||
|
||||
|
||||
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
server = ThreadingHTTPServer((address, port), Handler)
|
||||
|
||||
ssl_certfile = shared.args.ssl_certfile
|
||||
ssl_keyfile = shared.args.ssl_keyfile
|
||||
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||
if ssl_verify:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||
|
||||
def on_start(public_url: str):
|
||||
logger.info(f'Blocking API URL: \n\n{public_url}/api\n')
|
||||
|
||||
if share:
|
||||
try:
|
||||
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
if ssl_verify:
|
||||
logger.info(f'Blocking API URL: \n\nhttps://{address}:{port}/api\n')
|
||||
else:
|
||||
logger.info(f'Blocking API URL: \n\nhttp://{address}:{port}/api\n')
|
||||
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def start_server(port: int, share: bool = False, tunnel_id=str):
|
||||
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
|
@ -1,2 +0,0 @@
|
||||
flask_cloudflared==0.0.14
|
||||
websockets==11.0.2
|
@ -1,15 +0,0 @@
|
||||
import time
|
||||
|
||||
import extensions.api.blocking_api as blocking_api
|
||||
import extensions.api.streaming_api as streaming_api
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def setup():
|
||||
logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API")
|
||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
|
||||
if shared.args.public_api:
|
||||
time.sleep(5)
|
||||
|
||||
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
|
@ -1,142 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from threading import Thread
|
||||
|
||||
from websockets.server import serve
|
||||
|
||||
from extensions.api.util import (
|
||||
build_parameters,
|
||||
try_start_cloudflared,
|
||||
with_api_lock
|
||||
)
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import generate_reply
|
||||
from modules.logging_colors import logger
|
||||
|
||||
PATH = '/api/v1/stream'
|
||||
|
||||
|
||||
@with_api_lock
|
||||
async def _handle_stream_message(websocket, message):
|
||||
message = json.loads(message)
|
||||
|
||||
prompt = message['prompt']
|
||||
generate_params = build_parameters(message)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = True
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
# As we stream, only send the new bytes.
|
||||
skip_index = 0
|
||||
message_num = 0
|
||||
|
||||
for a in generator:
|
||||
to_send = a[skip_index:]
|
||||
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': to_send
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
skip_index += len(to_send)
|
||||
message_num += 1
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
|
||||
@with_api_lock
|
||||
async def _handle_chat_stream_message(websocket, message):
|
||||
body = json.loads(message)
|
||||
|
||||
user_input = body['user_input']
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = True
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
message_num = 0
|
||||
for a in generator:
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'history': a
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
message_num += 1
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
|
||||
async def _handle_connection(websocket, path):
|
||||
|
||||
if path == '/api/v1/stream':
|
||||
async for message in websocket:
|
||||
await _handle_stream_message(websocket, message)
|
||||
|
||||
elif path == '/api/v1/chat-stream':
|
||||
async for message in websocket:
|
||||
await _handle_chat_stream_message(websocket, message)
|
||||
|
||||
else:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
ssl_certfile = shared.args.ssl_certfile
|
||||
ssl_keyfile = shared.args.ssl_keyfile
|
||||
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||
if ssl_verify:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||
else:
|
||||
context = None
|
||||
|
||||
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
|
||||
await asyncio.Future() # Run the server forever
|
||||
|
||||
|
||||
def _run_server(port: int, share: bool = False, tunnel_id=str):
|
||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
ssl_certfile = shared.args.ssl_certfile
|
||||
ssl_keyfile = shared.args.ssl_keyfile
|
||||
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
|
||||
|
||||
def on_start(public_url: str):
|
||||
public_url = public_url.replace('https://', 'wss://')
|
||||
logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')
|
||||
|
||||
if share:
|
||||
try:
|
||||
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
else:
|
||||
if ssl_verify:
|
||||
logger.info(f'Streaming API URL: \n\nwss://{address}:{port}{PATH}\n')
|
||||
else:
|
||||
logger.info(f'Streaming API URL: \n\nws://{address}:{port}{PATH}\n')
|
||||
|
||||
asyncio.run(_run(host=address, port=port))
|
||||
|
||||
|
||||
def start_server(port: int, share: bool = False, tunnel_id=str):
|
||||
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()
|
@ -1,156 +0,0 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from modules import shared
|
||||
from modules.chat import load_character_memoized
|
||||
from modules.presets import load_preset_memoized
|
||||
|
||||
# We use a thread local to store the asyncio lock, so that each thread
|
||||
# has its own lock. This isn't strictly necessary, but it makes it
|
||||
# such that if we can support multiple worker threads in the future,
|
||||
# thus handling multiple requests in parallel.
|
||||
api_tls = threading.local()
|
||||
|
||||
|
||||
def build_parameters(body, chat=False):
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
|
||||
'max_tokens_second': int(body.get('max_tokens_second', 0)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'temperature_last': bool(body.get('temperature_last', False)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'min_p': float(body.get('min_p', 0)),
|
||||
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
|
||||
'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
|
||||
'eta_cutoff': float(body.get('eta_cutoff', 0)),
|
||||
'tfs': float(body.get('tfs', 1)),
|
||||
'top_a': float(body.get('top_a', 0)),
|
||||
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
|
||||
'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))),
|
||||
'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))),
|
||||
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
|
||||
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'mirostat_mode': int(body.get('mirostat_mode', 0)),
|
||||
'mirostat_tau': float(body.get('mirostat_tau', 5)),
|
||||
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
|
||||
'grammar_string': str(body.get('grammar_string', '')),
|
||||
'guidance_scale': float(body.get('guidance_scale', 1)),
|
||||
'negative_prompt': str(body.get('negative_prompt', '')),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': bool(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
|
||||
'custom_token_bans': str(body.get('custom_token_bans', '')),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
|
||||
preset_name = body.get('preset', 'None')
|
||||
if preset_name not in ['None', None, '']:
|
||||
preset = load_preset_memoized(preset_name)
|
||||
generate_params.update(preset)
|
||||
|
||||
if chat:
|
||||
character = body.get('character')
|
||||
instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
|
||||
if str(instruction_template) == "None":
|
||||
instruction_template = "Vicuna-v1.1"
|
||||
if str(character) == "None":
|
||||
character = "Assistant"
|
||||
|
||||
name1, name2, _, greeting, context, _, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False)
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, _ = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||
generate_params.update({
|
||||
'mode': str(body.get('mode', 'chat')),
|
||||
'name1': str(body.get('name1', name1)),
|
||||
'name2': str(body.get('name2', name2)),
|
||||
'context': str(body.get('context', context)),
|
||||
'greeting': str(body.get('greeting', greeting)),
|
||||
'name1_instruct': str(body.get('name1_instruct', name1_instruct)),
|
||||
'name2_instruct': str(body.get('name2_instruct', name2_instruct)),
|
||||
'context_instruct': str(body.get('context_instruct', context_instruct)),
|
||||
'turn_template': str(body.get('turn_template', turn_template)),
|
||||
'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))),
|
||||
'history': body.get('history', {'internal': [], 'visible': []})
|
||||
})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
Thread(target=_start_cloudflared, args=[
|
||||
port, tunnel_id, max_attempts, on_start], daemon=True).start()
|
||||
|
||||
|
||||
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
raise Exception(
|
||||
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
if tunnel_id is not None:
|
||||
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
||||
else:
|
||||
public_url = _run_cloudflared(port, port + 1)
|
||||
|
||||
if on_start:
|
||||
on_start(public_url)
|
||||
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
time.sleep(3)
|
||||
|
||||
raise Exception('Could not start cloudflared.')
|
||||
|
||||
|
||||
def _get_api_lock(tls) -> asyncio.Lock:
|
||||
"""
|
||||
The streaming and blocking API implementations each run on their own
|
||||
thread, and multiplex requests using asyncio. If multiple outstanding
|
||||
requests are received at once, we will try to acquire the shared lock
|
||||
shared.generation_lock multiple times in succession in the same thread,
|
||||
which will cause a deadlock.
|
||||
|
||||
To avoid this, we use this wrapper function to block on an asyncio
|
||||
lock, and then try and grab the shared lock only while holding
|
||||
the asyncio lock.
|
||||
"""
|
||||
if not hasattr(tls, "asyncio_lock"):
|
||||
tls.asyncio_lock = asyncio.Lock()
|
||||
|
||||
return tls.asyncio_lock
|
||||
|
||||
|
||||
def with_api_lock(func):
|
||||
"""
|
||||
This decorator should be added to all streaming API methods which
|
||||
require access to the shared.generation_lock. It ensures that the
|
||||
tls.asyncio_lock is acquired before the method is called, and
|
||||
released afterwards.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
async def api_wrapper(*args, **kwargs):
|
||||
async with _get_api_lock(api_tls):
|
||||
return await func(*args, **kwargs)
|
||||
return api_wrapper
|
@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
# preload the embedding model, useful for Docker images to prevent re-download on config change
|
||||
# Dockerfile:
|
||||
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
|
||||
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
|
||||
# RUN python3 cache_embedded_model.py
|
||||
import os
|
||||
|
||||
import sentence_transformers
|
||||
|
||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
|
||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
|
||||
model = sentence_transformers.SentenceTransformer(st_model)
|
||||
|
@ -204,8 +204,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, system_message = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||
name1_instruct = body['name1_instruct'] or name1_instruct
|
||||
name2_instruct = body['name2_instruct'] or name2_instruct
|
||||
context_instruct = body['context_instruct'] or context_instruct
|
||||
turn_template = body['turn_template'] or turn_template
|
||||
context_instruct = body['context_instruct'] or context_instruct
|
||||
system_message = body['system_message'] or system_message
|
||||
|
||||
# Chat character
|
||||
character = body['character'] or shared.settings['character']
|
||||
|
@ -3,9 +3,7 @@ import os
|
||||
import numpy as np
|
||||
from extensions.openai.errors import ServiceUnavailableError
|
||||
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||
from transformers import AutoModel
|
||||
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
embeddings_params_initialized = False
|
||||
|
||||
@ -17,38 +15,44 @@ def initialize_embedding_params():
|
||||
'''
|
||||
global embeddings_params_initialized
|
||||
if not embeddings_params_initialized:
|
||||
global st_model, embeddings_model, embeddings_device
|
||||
from extensions.openai.script import params
|
||||
|
||||
global st_model, embeddings_model, embeddings_device
|
||||
|
||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
||||
embeddings_model = None
|
||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
||||
if embeddings_device.lower() == 'auto':
|
||||
embeddings_device = None
|
||||
|
||||
embeddings_params_initialized = True
|
||||
|
||||
|
||||
def load_embedding_model(model: str):
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ModuleNotFoundError:
|
||||
logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
|
||||
raise ModuleNotFoundError
|
||||
|
||||
initialize_embedding_params()
|
||||
global embeddings_device, embeddings_model
|
||||
try:
|
||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||
trust = shared.args.trust_remote_code
|
||||
if embeddings_device == 'cpu':
|
||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float)
|
||||
else: #use the auto mode
|
||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust)
|
||||
print(f"\nLoaded embedding model: {model} on {embeddings_model.device}")
|
||||
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||
print(f"Loaded embedding model: {model}")
|
||||
except Exception as e:
|
||||
embeddings_model = None
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
|
||||
def get_embeddings_model() -> AutoModel:
|
||||
def get_embeddings_model():
|
||||
initialize_embedding_params()
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
load_embedding_model(st_model) # lazy load the model
|
||||
|
||||
return embeddings_model
|
||||
|
||||
|
||||
@ -67,9 +71,7 @@ def get_embeddings(input: list) -> np.ndarray:
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str) -> dict:
|
||||
|
||||
embeddings = get_embeddings(input)
|
||||
|
||||
if encoding_format == "base64":
|
||||
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||
else:
|
||||
@ -86,5 +88,4 @@ def embeddings(input: list, encoding_format: str) -> dict:
|
||||
}
|
||||
|
||||
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||
|
||||
return response
|
||||
|
@ -55,6 +55,7 @@ def _load_model(data):
|
||||
setattr(shared.args, k, args[k])
|
||||
|
||||
shared.model, shared.tokenizer = load_model(model_name)
|
||||
shared.model_name = model_name
|
||||
|
||||
# Update shared.settings with custom generation defaults
|
||||
if settings:
|
||||
|
@ -1,5 +1,4 @@
|
||||
SpeechRecognition==3.10.0
|
||||
flask_cloudflared==0.0.14
|
||||
sentence-transformers
|
||||
sse-starlette==1.6.5
|
||||
tiktoken
|
||||
|
@ -31,6 +31,8 @@ from .typing import (
|
||||
CompletionResponse,
|
||||
DecodeRequest,
|
||||
DecodeResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
EncodeResponse,
|
||||
LoadModelRequest,
|
||||
@ -41,7 +43,7 @@ from .typing import (
|
||||
|
||||
params = {
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'all-mpnet-base-v2',
|
||||
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||
'sd_webui_url': '',
|
||||
'debug': 0
|
||||
}
|
||||
@ -196,19 +198,16 @@ async def handle_image_generation(request: Request):
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
async def handle_embeddings(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
|
||||
input = body.get('input', body.get('text', ''))
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
|
||||
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
||||
input = request_data.input
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
response = OAIembeddings.embeddings(input, request_data.encoding_format)
|
||||
return JSONResponse(response)
|
||||
|
||||
|
||||
@ -295,14 +294,14 @@ def run_server():
|
||||
|
||||
if shared.args.public_api:
|
||||
def on_start(public_url: str):
|
||||
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
|
||||
logger.info(f'OpenAI compatible API URL:\n\n{public_url}\n')
|
||||
|
||||
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
|
||||
else:
|
||||
if ssl_keyfile and ssl_certfile:
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}\n')
|
||||
else:
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}\n')
|
||||
|
||||
if shared.args.api_key:
|
||||
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
|
||||
|
@ -42,7 +42,7 @@ class GenerationOptions(BaseModel):
|
||||
|
||||
|
||||
class CompletionRequestParams(BaseModel):
|
||||
model: str | None = None
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||
prompt: str | List[str]
|
||||
best_of: int | None = Field(default=1, description="Unused parameter.")
|
||||
echo: bool | None = False
|
||||
@ -75,7 +75,7 @@ class CompletionResponse(BaseModel):
|
||||
|
||||
class ChatCompletionRequestParams(BaseModel):
|
||||
messages: List[dict]
|
||||
model: str | None = None
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
|
||||
frequency_penalty: float | None = 0
|
||||
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
|
||||
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
|
||||
@ -92,10 +92,11 @@ class ChatCompletionRequestParams(BaseModel):
|
||||
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
||||
|
||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
|
||||
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
system_message: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
|
||||
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
|
||||
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||
@ -153,6 +154,19 @@ class LoadModelRequest(BaseModel):
|
||||
settings: dict | None = None
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
input: str | List[str]
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||
encoding_format: str = Field(default="float", description="Can be float or base64.")
|
||||
user: str | None = Field(default=None, description="Unused parameter.")
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
index: int
|
||||
embedding: List[float]
|
||||
object: str = "embedding"
|
||||
|
||||
|
||||
def to_json(obj):
|
||||
return json.dumps(obj.__dict__, indent=4)
|
||||
|
||||
|
@ -258,9 +258,8 @@ if args.multimodal_pipeline is not None:
|
||||
add_extension('multimodal')
|
||||
|
||||
# Activate the API extension
|
||||
if args.api:
|
||||
# add_extension('openai', last=True)
|
||||
add_extension('api', last=True)
|
||||
if args.api or args.public_api:
|
||||
add_extension('openai', last=True)
|
||||
|
||||
# Load model-specific settings
|
||||
with Path(f'{args.model_dir}/config.yaml') as p:
|
||||
|
@ -190,7 +190,7 @@ def install_webui():
|
||||
use_cuda118 = "Y" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "N"
|
||||
else:
|
||||
# Ask for CUDA version if using NVIDIA
|
||||
print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n")
|
||||
print("\nDo you want to use CUDA 11.8 instead of 12.1? Only choose this option if your GPU is very old (Kepler or older).\nFor RTX and GTX series GPUs, say \"N\". If unsure, say \"N\".\n")
|
||||
use_cuda118 = input("Input (Y/N)> ").upper().strip('"\'').strip()
|
||||
while use_cuda118 not in 'YN':
|
||||
print("Invalid choice. Please try again.")
|
||||
|
@ -9,6 +9,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
|
||||
|
||||
with RequestBlocker():
|
||||
import gradio as gr
|
||||
|
Loading…
Reference in New Issue
Block a user