mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
[extension/openai] add edits & image endpoints & fix prompt return in non --chat modes (#1935)
This commit is contained in:
parent
23d3f6909a
commit
309b72e549
@ -1,4 +1,4 @@
|
|||||||
user: "[Round <|round|>]\n问:"
|
user: "[Round <|round|>]\n问:"
|
||||||
bot: "答:"
|
bot: "答:"
|
||||||
turn_template: "<|user|><|user-message|>\n<|bot|><|bot-message|>\n"
|
turn_template: "<|user|><|user-message|>\n<|bot|><|bot-message|>\n"
|
||||||
context: ""
|
context: ""
|
||||||
|
@ -11,6 +11,15 @@ Optional (for flask_cloudflared, embeddings):
|
|||||||
pip3 install -r requirements.txt
|
pip3 install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this.
|
||||||
|
|
||||||
|
To enable the bare bones image generation (txt2img) set: SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
SD_WEBUI_URL=http://127.0.0.1:7861
|
||||||
|
```
|
||||||
|
|
||||||
### Embeddings (alpha)
|
### Embeddings (alpha)
|
||||||
|
|
||||||
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
||||||
@ -67,17 +76,22 @@ const api = new ChatGPTAPI({
|
|||||||
|
|
||||||
## Compatibility & not so compatibility
|
## Compatibility & not so compatibility
|
||||||
|
|
||||||
What's working:
|
|
||||||
|
|
||||||
| API endpoint | tested with | notes |
|
| API endpoint | tested with | notes |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| /v1/models | openai.Model.list() | returns the currently loaded model_name and some mock compatibility options |
|
| /v1/models | openai.Model.list() | returns the currently loaded model_name and some mock compatibility options |
|
||||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways |
|
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways |
|
||||||
| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far |
|
| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far |
|
||||||
| /v1/chat/completions | openai.ChatCompletion.create() | depending on the model, this may add leading linefeeds |
|
| /v1/chat/completions | openai.ChatCompletion.create() | depending on the model, this may add leading linefeeds |
|
||||||
|
| /v1/edits | openai.Edit.create() | Assumes an instruction following model, but may work with others |
|
||||||
|
| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
|
||||||
| /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. |
|
| /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. |
|
||||||
| /v1/moderations | openai.Moderation.create() | does nothing. successfully. |
|
| /v1/moderations | openai.Moderation.create() | does nothing. successfully. |
|
||||||
| /v1/engines/\*/... completions, embeddings, generate | python-openai v0.25 and earlier | Legacy engines endpoints |
|
| /v1/engines/\*/... completions, embeddings, generate | python-openai v0.25 and earlier | Legacy engines endpoints |
|
||||||
|
| /v1/images/edits | openai.Image.create_edit() | not supported |
|
||||||
|
| /v1/images/variations | openai.Image.create_variation() | not supported |
|
||||||
|
| /v1/audio/\* | openai.Audio.\* | not supported |
|
||||||
|
| /v1/files\* | openai.Files.\* | not supported |
|
||||||
|
| /v1/fine-tunes\* | openai.FineTune.\* | not supported |
|
||||||
|
|
||||||
The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k).
|
The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k).
|
||||||
|
|
||||||
@ -99,6 +113,10 @@ Some hacky mappings:
|
|||||||
|
|
||||||
defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
|
defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
|
||||||
|
|
||||||
|
### Models
|
||||||
|
|
||||||
|
This has been successfully tested with Koala, Alpaca, gpt4-x-alpaca, GPT4all-snoozy, wizard-vicuna, stable-vicuna and Vicuna 1.1 - ie. Instruction Following models. If you test with other models please let me know how it goes. Less than satisfying results (so far): RWKV-4-Raven, llama, mpt-7b-instruct/chat
|
||||||
|
|
||||||
### Applications
|
### Applications
|
||||||
|
|
||||||
Everything needs OPENAI_API_KEY=dummy set.
|
Everything needs OPENAI_API_KEY=dummy set.
|
||||||
@ -120,4 +138,7 @@ Everything needs OPENAI_API_KEY=dummy set.
|
|||||||
* model changing, esp. something for swapping loras or embedding models
|
* model changing, esp. something for swapping loras or embedding models
|
||||||
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||||
* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM
|
* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM
|
||||||
* the whole api, images (stable diffusion), audio (whisper), fine-tunes (training), edits, files, etc.
|
|
||||||
|
## Bugs? Feedback? Comments? Pull requests?
|
||||||
|
|
||||||
|
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.
|
||||||
|
8
extensions/openai/cache_embedding_model.py
Executable file
8
extensions/openai/cache_embedding_model.py
Executable file
@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# preload the embedding model, useful for Docker images to prevent re-download on config change
|
||||||
|
# Dockerfile:
|
||||||
|
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
|
||||||
|
# RUN python3 cache_embedded_model.py
|
||||||
|
import os, sentence_transformers
|
||||||
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
|
model = sentence_transformers.SentenceTransformer(st_model)
|
@ -2,6 +2,8 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import requests
|
||||||
|
import yaml
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@ -48,6 +50,31 @@ def clamp(value, minvalue, maxvalue):
|
|||||||
return max(minvalue, min(value, maxvalue))
|
return max(minvalue, min(value, maxvalue))
|
||||||
|
|
||||||
|
|
||||||
|
def deduce_template():
|
||||||
|
# Alpaca is verbose so a good default prompt
|
||||||
|
default_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||||
|
if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
|
||||||
|
return default_template
|
||||||
|
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
template = template\
|
||||||
|
.replace('<|user|>', instruct.get('user', ''))\
|
||||||
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||||
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||||
|
return instruct.get('context', '') + template[:template.find('<|bot-message|>')]
|
||||||
|
except:
|
||||||
|
return default_template
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_list):
|
def float_list_to_base64(float_list):
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
float_array = np.array(float_list, dtype="float32")
|
float_array = np.array(float_list, dtype="float32")
|
||||||
@ -120,11 +147,20 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
|
# ... haaack.
|
||||||
|
is_chat = shared.args.chat
|
||||||
|
try:
|
||||||
|
shared.args.chat = True
|
||||||
|
self.do_POST_wrap()
|
||||||
|
finally:
|
||||||
|
shared.args.chat = is_chat
|
||||||
|
|
||||||
|
def do_POST_wrap(self):
|
||||||
|
if debug:
|
||||||
|
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
|
||||||
if debug:
|
if debug:
|
||||||
print(body)
|
print(body)
|
||||||
|
|
||||||
@ -150,7 +186,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
||||||
|
|
||||||
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf"
|
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it.
|
||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||||
@ -440,6 +476,129 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
else:
|
else:
|
||||||
resp[resp_list][0]["text"] = answer
|
resp[resp_list][0]["text"] = answer
|
||||||
|
|
||||||
|
response = json.dumps(resp)
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
elif '/edits' in self.path:
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
created_time = int(time.time())
|
||||||
|
|
||||||
|
# Using Alpaca format, this may work with other models too.
|
||||||
|
instruction = body['instruction']
|
||||||
|
input = body.get('input', '')
|
||||||
|
|
||||||
|
instruction_template = deduce_template()
|
||||||
|
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||||
|
|
||||||
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
|
token_count = len(encode(edit_task)[0])
|
||||||
|
max_tokens = truncation_length - token_count
|
||||||
|
|
||||||
|
req_params = {
|
||||||
|
'max_new_tokens': max_tokens,
|
||||||
|
'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
|
||||||
|
'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
|
||||||
|
'top_k': 1,
|
||||||
|
'repetition_penalty': 1.18,
|
||||||
|
'encoder_repetition_penalty': 1.0,
|
||||||
|
'suffix': None,
|
||||||
|
'stream': False,
|
||||||
|
'echo': False,
|
||||||
|
'seed': shared.settings.get('seed', -1),
|
||||||
|
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||||
|
'truncation_length': truncation_length,
|
||||||
|
'add_bos_token': shared.settings.get('add_bos_token', True),
|
||||||
|
'do_sample': True,
|
||||||
|
'typical_p': 1.0,
|
||||||
|
'min_length': 0,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0.0,
|
||||||
|
'length_penalty': 1,
|
||||||
|
'early_stopping': False,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'skip_special_tokens': True,
|
||||||
|
'custom_stopping_strings': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||||
|
|
||||||
|
generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
for a in generator:
|
||||||
|
if isinstance(a, str):
|
||||||
|
answer = a
|
||||||
|
else:
|
||||||
|
answer = a[0]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"object": "edit",
|
||||||
|
"created": created_time,
|
||||||
|
"choices": [{
|
||||||
|
"text": answer,
|
||||||
|
"index": 0,
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print({'answer': answer, 'completion_token_count': completion_token_count})
|
||||||
|
|
||||||
|
response = json.dumps(resp)
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||||
|
# Stable Diffusion callout wrapper for txt2img
|
||||||
|
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||||
|
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
||||||
|
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
||||||
|
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
||||||
|
# Will probably work best with the stock SD models.
|
||||||
|
# SD configuration is beyond the scope of this API.
|
||||||
|
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||||
|
# require changing the form data handling to accept multipart form data, also to properly support
|
||||||
|
# url return types will require file management and a web serving files... Perhaps later!
|
||||||
|
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ] # ignore the restrictions on size
|
||||||
|
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'prompt': body['prompt'], # ignore prompt limit of 1000 characters
|
||||||
|
'width': width,
|
||||||
|
'height': height,
|
||||||
|
'batch_size': default(body, 'n', 1) # ignore the batch limits of max 10
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
'created': int(time.time()),
|
||||||
|
'data': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||||
|
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
|
response = requests.post(url=sd_url, json=payload)
|
||||||
|
r = response.json()
|
||||||
|
# r['parameters']...
|
||||||
|
for b64_json in r['images']:
|
||||||
|
if response_format == 'b64_json':
|
||||||
|
resp['data'].extend([{'b64_json': b64_json}])
|
||||||
|
else:
|
||||||
|
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
||||||
|
|
||||||
response = json.dumps(resp)
|
response = json.dumps(resp)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
elif '/embeddings' in self.path and embedding_model is not None:
|
elif '/embeddings' in self.path and embedding_model is not None:
|
||||||
@ -540,11 +699,12 @@ def run_server():
|
|||||||
try:
|
try:
|
||||||
from flask_cloudflared import _run_cloudflared
|
from flask_cloudflared import _run_cloudflared
|
||||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||||
print(f'Starting OpenAI compatible api at {public_url}/')
|
print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print('You should install flask_cloudflared manually')
|
print('You should install flask_cloudflared manually')
|
||||||
else:
|
else:
|
||||||
print(f'Starting OpenAI compatible api at http://{server_addr[0]}:{server_addr[1]}/')
|
print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||||
|
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,6 +54,9 @@
|
|||||||
.*vicuna.*(1.1|1_1):
|
.*vicuna.*(1.1|1_1):
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'Vicuna-v1.1'
|
instruction_template: 'Vicuna-v1.1'
|
||||||
|
.*wizard.*vicuna:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Vicuna-v1.1'
|
||||||
.*stable.*vicuna:
|
.*stable.*vicuna:
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'StableVicuna'
|
instruction_template: 'StableVicuna'
|
||||||
|
Loading…
Reference in New Issue
Block a user