Merge pull request #4579 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-11-13 11:39:08 -03:00 committed by GitHub
commit 454fcf39a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 85 additions and 612 deletions

View File

@ -22,7 +22,7 @@
"source": [
"# oobabooga/text-generation-webui\n",
"\n",
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate API links.\n",
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate an API link.\n",
"\n",
"* Project page: https://github.com/oobabooga/text-generation-webui\n",
"* Gradio server status: https://status.gradio.app/"
@ -75,7 +75,7 @@
" with open('temp_requirements.txt', 'w') as file:\n",
" file.write('\\n'.join(textgen_requirements))\n",
"\n",
" !pip install -r extensions/api/requirements.txt --upgrade\n",
" !pip install -r extensions/openai/requirements.txt --upgrade\n",
" !pip install -r temp_requirements.txt --upgrade\n",
"\n",
" print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",

View File

@ -1,6 +1,11 @@
## OpenAI compatible API
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints.
* It is 100% offline and private.
* It doesn't create any logs.
* It doesn't connect to OpenAI.
* It doesn't use the openai-python library.
If you did not use the one-click installers, you may need to install the requirements first:
@ -10,7 +15,7 @@ pip install -r extensions/openai/requirements.txt
### Starting the API
Add `--extensions openai` to your command-line flags.
Add `--api` to your command-line flags.
* To create a public Cloudflare URL, add the `--public-api` flag.
* To listen on your local network, add the `--listen` flag.
@ -18,31 +23,6 @@ Add `--extensions openai` to your command-line flags.
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
* To use an API key for authentication, add `--api-key yourkey`.
#### Environment variables
The following environment variables can be used (they take precendence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set the following variables in your `settings.yaml` file:
```
openai-embedding_device: cuda
openai-embedding_model: all-mpnet-base-v2
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Examples
For the documentation with all the parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file.
@ -220,6 +200,31 @@ for event in client.events():
print()
```
### Environment variables
The following environment variables can be used (they take precendence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set the following variables in your `settings.yaml` file:
```
openai-embedding_device: cuda
openai-embedding_model: "sentence-transformers/all-mpnet-base-v2"
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```
### Third-party application setup
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:

View File

@ -1,232 +0,0 @@
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import (
encode,
generate_reply,
stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger
def get_model_info():
return {
'model_name': shared.model_name,
'lora_names': shared.lora_names,
# dump
'shared.settings': shared.settings,
'shared.args': vars(shared.args),
}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
self.send_response(200)
self.end_headers()
response = json.dumps({
'result': shared.model_name
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if self.path == '/api/v1/generate':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
response = json.dumps({
'results': [{
'text': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/chat':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
user_input = body['user_input']
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = False
generator = generate_chat_reply(
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = generate_params['history']
for a in generator:
answer = a
response = json.dumps({
'results': [{
'history': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/stop-stream':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
stop_everything_event()
response = json.dumps({
'results': 'success'
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/model':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
# by default return the same as the GET interface
result = shared.model_name
# Actions: info, load, list, unload
action = body.get('action', '')
if action == 'load':
model_name = body['model_name']
args = body.get('args', {})
print('args', args)
for k in args:
setattr(shared.args, k, args[k])
shared.model_name = model_name
unload_model()
model_settings = get_model_metadata(shared.model_name)
shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings})
update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct':
shared.settings['instruction_template'] = None
try:
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora) # list
except Exception as e:
response = json.dumps({'error': {'message': repr(e)}})
self.wfile.write(response.encode('utf-8'))
raise e
shared.args.model = shared.model_name
result = get_model_info()
elif action == 'unload':
unload_model()
shared.model_name = None
shared.args.model = None
result = get_model_info()
elif action == 'list':
result = get_available_models()
elif action == 'info':
result = get_model_info()
response = json.dumps({
'result': result,
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
tokens = encode(body['prompt'])[0]
response = json.dumps({
'results': [{
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_OPTIONS(self):
self.send_response(200)
self.end_headers()
def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
self.send_header('Access-Control-Allow-Headers', '*')
self.send_header('Cache-Control', 'no-store, no-cache, must-revalidate')
super().end_headers()
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
def on_start(public_url: str):
logger.info(f'Blocking API URL: \n\n{public_url}/api\n')
if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception:
pass
else:
if ssl_verify:
logger.info(f'Blocking API URL: \n\nhttps://{address}:{port}/api\n')
else:
logger.info(f'Blocking API URL: \n\nhttp://{address}:{port}/api\n')
server.serve_forever()
def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View File

@ -1,2 +0,0 @@
flask_cloudflared==0.0.14
websockets==11.0.2

View File

@ -1,15 +0,0 @@
import time
import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
from modules.logging_colors import logger
def setup():
logger.warning("\nThe current API is deprecated and will be replaced with the OpenAI compatible API on November 13th.\nTo test the new API, use \"--extensions openai\" instead of \"--api\".\nFor documentation on the new API, consult:\nhttps://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API")
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
if shared.args.public_api:
time.sleep(5)
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)

View File

@ -1,142 +0,0 @@
import asyncio
import json
import ssl
from threading import Thread
from websockets.server import serve
from extensions.api.util import (
build_parameters,
try_start_cloudflared,
with_api_lock
)
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from modules.logging_colors import logger
PATH = '/api/v1/stream'
@with_api_lock
async def _handle_stream_message(websocket, message):
message = json.loads(message)
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = 0
message_num = 0
for a in generator:
to_send = a[skip_index:]
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
continue
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
}))
await asyncio.sleep(0)
skip_index += len(to_send)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
body = json.loads(message)
user_input = body['user_input']
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = True
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generator = generate_chat_reply(
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
message_num = 0
for a in generator:
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'history': a
}))
await asyncio.sleep(0)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
async def _handle_connection(websocket, path):
if path == '/api/v1/stream':
async for message in websocket:
await _handle_stream_message(websocket, message)
elif path == '/api/v1/chat-stream':
async for message in websocket:
await _handle_chat_stream_message(websocket, message)
else:
print(f'Streaming api: unknown path: {path}')
return
async def _run(host: str, port: int):
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
else:
context = None
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
await asyncio.Future() # Run the server forever
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://')
logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')
if share:
try:
try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
except Exception as e:
print(e)
else:
if ssl_verify:
logger.info(f'Streaming API URL: \n\nwss://{address}:{port}{PATH}\n')
else:
logger.info(f'Streaming API URL: \n\nws://{address}:{port}{PATH}\n')
asyncio.run(_run(host=address, port=port))
def start_server(port: int, share: bool = False, tunnel_id=str):
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()

View File

@ -1,156 +0,0 @@
import asyncio
import functools
import threading
import time
import traceback
from threading import Thread
from typing import Callable, Optional
from modules import shared
from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future,
# thus handling multiple requests in parallel.
api_tls = threading.local()
def build_parameters(body, chat=False):
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
'max_tokens_second': int(body.get('max_tokens_second', 0)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'temperature_last': bool(body.get('temperature_last', False)),
'top_p': float(body.get('top_p', 1)),
'min_p': float(body.get('min_p', 0)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
'eta_cutoff': float(body.get('eta_cutoff', 0)),
'tfs': float(body.get('tfs', 1)),
'top_a': float(body.get('top_a', 0)),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'presence_penalty': float(body.get('presence_penalty', body.get('presence_pen', 0))),
'frequency_penalty': float(body.get('frequency_penalty', body.get('frequency_pen', 0))),
'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'mirostat_mode': int(body.get('mirostat_mode', 0)),
'mirostat_tau': float(body.get('mirostat_tau', 5)),
'mirostat_eta': float(body.get('mirostat_eta', 0.1)),
'grammar_string': str(body.get('grammar_string', '')),
'guidance_scale': float(body.get('guidance_scale', 1)),
'negative_prompt': str(body.get('negative_prompt', '')),
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
'custom_token_bans': str(body.get('custom_token_bans', '')),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
'stopping_strings': body.get('stopping_strings', []),
}
preset_name = body.get('preset', 'None')
if preset_name not in ['None', None, '']:
preset = load_preset_memoized(preset_name)
generate_params.update(preset)
if chat:
character = body.get('character')
instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
if str(instruction_template) == "None":
instruction_template = "Vicuna-v1.1"
if str(character) == "None":
character = "Assistant"
name1, name2, _, greeting, context, _, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), '', instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, _ = load_character_memoized(instruction_template, '', '', instruct=True)
generate_params.update({
'mode': str(body.get('mode', 'chat')),
'name1': str(body.get('name1', name1)),
'name2': str(body.get('name2', name2)),
'context': str(body.get('context', context)),
'greeting': str(body.get('greeting', greeting)),
'name1_instruct': str(body.get('name1_instruct', name1_instruct)),
'name2_instruct': str(body.get('name2_instruct', name2_instruct)),
'context_instruct': str(body.get('context_instruct', context_instruct)),
'turn_template': str(body.get('turn_template', turn_template)),
'chat-instruct_command': str(body.get('chat_instruct_command', body.get('chat-instruct_command', shared.settings['chat-instruct_command']))),
'history': body.get('history', {'internal': [], 'visible': []})
})
return generate_params
def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
Thread(target=_start_cloudflared, args=[
port, tunnel_id, max_attempts, on_start], daemon=True).start()
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')
def _get_api_lock(tls) -> asyncio.Lock:
"""
The streaming and blocking API implementations each run on their own
thread, and multiplex requests using asyncio. If multiple outstanding
requests are received at once, we will try to acquire the shared lock
shared.generation_lock multiple times in succession in the same thread,
which will cause a deadlock.
To avoid this, we use this wrapper function to block on an asyncio
lock, and then try and grab the shared lock only while holding
the asyncio lock.
"""
if not hasattr(tls, "asyncio_lock"):
tls.asyncio_lock = asyncio.Lock()
return tls.asyncio_lock
def with_api_lock(func):
"""
This decorator should be added to all streaming API methods which
require access to the shared.generation_lock. It ensures that the
tls.asyncio_lock is acquired before the method is called, and
released afterwards.
"""
@functools.wraps(func)
async def api_wrapper(*args, **kwargs):
async with _get_api_lock(api_tls):
return await func(*args, **kwargs)
return api_wrapper

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3
# preload the embedding model, useful for Docker images to prevent re-download on config change
# Dockerfile:
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
# ENV OPENEDAI_EMBEDDING_MODEL="sentence-transformers/all-mpnet-base-v2" # Optional
# RUN python3 cache_embedded_model.py
import os
import sentence_transformers
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
model = sentence_transformers.SentenceTransformer(st_model)

View File

@ -204,8 +204,9 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -
name1_instruct, name2_instruct, _, _, context_instruct, turn_template, system_message = load_character_memoized(instruction_template, '', '', instruct=True)
name1_instruct = body['name1_instruct'] or name1_instruct
name2_instruct = body['name2_instruct'] or name2_instruct
context_instruct = body['context_instruct'] or context_instruct
turn_template = body['turn_template'] or turn_template
context_instruct = body['context_instruct'] or context_instruct
system_message = body['system_message'] or system_message
# Chat character
character = body['character'] or shared.settings['character']

View File

@ -3,9 +3,7 @@ import os
import numpy as np
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64
from transformers import AutoModel
from modules import shared
from modules.logging_colors import logger
embeddings_params_initialized = False
@ -17,38 +15,44 @@ def initialize_embedding_params():
'''
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
from extensions.openai.script import params
global st_model, embeddings_model, embeddings_device
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
embeddings_model = None
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
if embeddings_device.lower() == 'auto':
embeddings_device = None
embeddings_params_initialized = True
def load_embedding_model(model: str):
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
logger.error("The sentence_transformers module has not been found. Please install it manually with pip install -U sentence-transformers.")
raise ModuleNotFoundError
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
print(f"Try embedding model: {model} on {embeddings_device}")
trust = shared.args.trust_remote_code
if embeddings_device == 'cpu':
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust).to("cpu", dtype=float)
else: #use the auto mode
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=trust)
print(f"\nLoaded embedding model: {model} on {embeddings_model.device}")
embeddings_model = SentenceTransformer(model, device=embeddings_device)
print(f"Loaded embedding model: {model}")
except Exception as e:
embeddings_model = None
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
def get_embeddings_model() -> AutoModel:
def get_embeddings_model():
initialize_embedding_params()
global embeddings_model, st_model
if st_model and not embeddings_model:
load_embedding_model(st_model) # lazy load the model
return embeddings_model
@ -67,9 +71,7 @@ def get_embeddings(input: list) -> np.ndarray:
def embeddings(input: list, encoding_format: str) -> dict:
embeddings = get_embeddings(input)
if encoding_format == "base64":
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
else:
@ -86,5 +88,4 @@ def embeddings(input: list, encoding_format: str) -> dict:
}
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
return response

View File

@ -55,6 +55,7 @@ def _load_model(data):
setattr(shared.args, k, args[k])
shared.model, shared.tokenizer = load_model(model_name)
shared.model_name = model_name
# Update shared.settings with custom generation defaults
if settings:

View File

@ -1,5 +1,4 @@
SpeechRecognition==3.10.0
flask_cloudflared==0.0.14
sentence-transformers
sse-starlette==1.6.5
tiktoken

View File

@ -31,6 +31,8 @@ from .typing import (
CompletionResponse,
DecodeRequest,
DecodeResponse,
EmbeddingsRequest,
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
LoadModelRequest,
@ -41,7 +43,7 @@ from .typing import (
params = {
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
'sd_webui_url': '',
'debug': 0
}
@ -196,19 +198,16 @@ async def handle_image_generation(request: Request):
return JSONResponse(response)
@app.post("/v1/embeddings")
async def handle_embeddings(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
input = body.get('input', body.get('text', ''))
@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
input = request_data.input
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
if type(input) is str:
input = [input]
response = OAIembeddings.embeddings(input, encoding_format)
response = OAIembeddings.embeddings(input, request_data.encoding_format)
return JSONResponse(response)
@ -295,14 +294,14 @@ def run_server():
if shared.args.public_api:
def on_start(public_url: str):
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
logger.info(f'OpenAI compatible API URL:\n\n{public_url}\n')
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
else:
if ssl_keyfile and ssl_certfile:
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}\n')
else:
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}\n')
if shared.args.api_key:
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')

View File

@ -42,7 +42,7 @@ class GenerationOptions(BaseModel):
class CompletionRequestParams(BaseModel):
model: str | None = None
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str]
best_of: int | None = Field(default=1, description="Unused parameter.")
echo: bool | None = False
@ -75,7 +75,7 @@ class CompletionResponse(BaseModel):
class ChatCompletionRequestParams(BaseModel):
messages: List[dict]
model: str | None = None
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
frequency_penalty: float | None = 0
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
@ -92,10 +92,11 @@ class ChatCompletionRequestParams(BaseModel):
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
system_message: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
@ -153,6 +154,19 @@ class LoadModelRequest(BaseModel):
settings: dict | None = None
class EmbeddingsRequest(BaseModel):
input: str | List[str]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
encoding_format: str = Field(default="float", description="Can be float or base64.")
user: str | None = Field(default=None, description="Unused parameter.")
class EmbeddingsResponse(BaseModel):
index: int
embedding: List[float]
object: str = "embedding"
def to_json(obj):
return json.dumps(obj.__dict__, indent=4)

View File

@ -258,9 +258,8 @@ if args.multimodal_pipeline is not None:
add_extension('multimodal')
# Activate the API extension
if args.api:
# add_extension('openai', last=True)
add_extension('api', last=True)
if args.api or args.public_api:
add_extension('openai', last=True)
# Load model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p:

View File

@ -190,7 +190,7 @@ def install_webui():
use_cuda118 = "Y" if os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "true", "1", "t", "on") else "N"
else:
# Ask for CUDA version if using NVIDIA
print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n")
print("\nDo you want to use CUDA 11.8 instead of 12.1? Only choose this option if your GPU is very old (Kepler or older).\nFor RTX and GTX series GPUs, say \"N\". If unsure, say \"N\".\n")
use_cuda118 = input("Input (Y/N)> ").upper().strip('"\'').strip()
while use_cuda118 not in 'YN':
print("Invalid choice. Please try again.")

View File

@ -9,6 +9,7 @@ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
with RequestBlocker():
import gradio as gr