mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
extensions/openai: Fixes for: embeddings, tokens, better errors. +Docs update, +Images, +logit_bias/logprobs, +more. (#3122)
This commit is contained in:
parent
1141987a0d
commit
90a4ab631c
@ -1,17 +1,15 @@
|
||||
# An OpenedAI API (openai like)
|
||||
|
||||
This extension creates an API that works kind of like openai (ie. api.openai.com).
|
||||
It's incomplete so far but perhaps is functional enough for you.
|
||||
|
||||
## Setup & installation
|
||||
|
||||
Optional (for flask_cloudflared, embeddings):
|
||||
|
||||
Install the requirements:
|
||||
```
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this.
|
||||
It listens on ```tcp port 5001``` by default. You can use the ```OPENEDAI_PORT``` environment variable to change this.
|
||||
|
||||
Make sure you enable it in server launch parameters, it should include:
|
||||
|
||||
@ -21,13 +19,30 @@ Make sure you enable it in server launch parameters, it should include:
|
||||
|
||||
You can also use the ``--listen`` argument to make the server available on the networ, and/or the ```--share``` argument to enable a public Cloudflare endpoint.
|
||||
|
||||
To enable the basic image generation support (txt2img) set the environment variable SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
||||
To enable the basic image generation support (txt2img) set the environment variable ```SD_WEBUI_URL``` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
||||
|
||||
For example:
|
||||
```
|
||||
SD_WEBUI_URL=http://127.0.0.1:7861
|
||||
```
|
||||
|
||||
## Quick start
|
||||
|
||||
1. Install the requirements.txt (pip)
|
||||
2. Enable the ```openeai``` module (--extensions openai), restart the server.
|
||||
3. Configure the openai client
|
||||
|
||||
Most openai application can be configured to connect the API if you set the following environment variables:
|
||||
|
||||
```shell
|
||||
# Sample .env file:
|
||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
||||
```
|
||||
|
||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
||||
|
||||
|
||||
### Models
|
||||
|
||||
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
|
||||
@ -36,7 +51,7 @@ For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](h
|
||||
|
||||
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
|
||||
|
||||
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following - within the limits of the model. Be sure that the proper instruction template is detected and loaded or the results will not be good.
|
||||
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
|
||||
|
||||
For the proper instruction format to be detected you need to have a matching model entry in your ```models/config.yaml``` file. Be sure to keep this file up to date.
|
||||
A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results.
|
||||
@ -76,7 +91,7 @@ Embeddings requires ```sentence-transformers``` installed, but chat and completi
|
||||
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
||||
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
||||
|
||||
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable OPENEDAI_EMBEDDING_MODEL, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
||||
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable ```OPENEDAI_EMBEDDING_MODEL```, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
||||
|
||||
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
||||
|
||||
@ -85,26 +100,27 @@ Warning: You cannot mix embeddings from different models even if they have the s
|
||||
|
||||
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
|
||||
|
||||
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
|
||||
With the [official python openai client](https://github.com/openai/openai-python), set the ```OPENAI_API_BASE``` environment variables:
|
||||
|
||||
```
|
||||
```shell
|
||||
# Sample .env file:
|
||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
||||
```
|
||||
|
||||
If needed, replace 127.0.0.1 with the IP/port of your server.
|
||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
||||
|
||||
If using .env files to save the OPENAI_API_BASE and OPENAI_API_KEY variables, you can ensure compatibility by loading the .env file before loading the openai module, like so in python:
|
||||
If using .env files to save the ```OPENAI_API_BASE``` and ```OPENAI_API_KEY``` variables, make sure the .env file is loaded before the openai module is imported:
|
||||
|
||||
```
|
||||
```python
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
load_dotenv() # make sure the environment variables are set before import
|
||||
import openai
|
||||
```
|
||||
|
||||
With the [official Node.js openai client](https://github.com/openai/openai-node) it is slightly more more complex because the environment variables are not used by default, so small source code changes may be required to use the environment variables, like so:
|
||||
|
||||
```
|
||||
```js
|
||||
const openai = OpenAI(Configuration({
|
||||
apiKey: process.env.OPENAI_API_KEY,
|
||||
basePath: process.env.OPENAI_API_BASE,
|
||||
@ -113,7 +129,7 @@ const openai = OpenAI(Configuration({
|
||||
|
||||
For apps made with the [chatgpt-api Node.js client library](https://github.com/transitive-bullshit/chatgpt-api):
|
||||
|
||||
```
|
||||
```js
|
||||
const api = new ChatGPTAPI({
|
||||
apiKey: process.env.OPENAI_API_KEY,
|
||||
apiBaseUrl: process.env.OPENAI_API_BASE,
|
||||
@ -127,39 +143,43 @@ The OpenAI API is well documented, you can view the documentation here: https://
|
||||
Examples of how to use the Completions API in Python can be found here: https://platform.openai.com/examples
|
||||
Not all of them will work with all models unfortunately, See the notes on Models for how to get the best results.
|
||||
|
||||
Here is a simple python example of how you can use the Edit endpoint as a translator.
|
||||
Here is a simple python example.
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ['OPENAI_API_KEY']="sk-111111111111111111111111111111111111111111111111"
|
||||
os.environ['OPENAI_API_BASE']="http://0.0.0.0:5001/v1"
|
||||
import openai
|
||||
response = openai.Edit.create(
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model="x",
|
||||
instruction="Translate this into French",
|
||||
input="Our mission is to ensure that artificial general intelligence benefits all of humanity.",
|
||||
messages = [{ 'role': 'system', 'content': "Answer in a consistent style." },
|
||||
{'role': 'user', 'content': "Teach me about patience."},
|
||||
{'role': 'assistant', 'content': "The river that carves the deepest valley flows from a modest spring; the grandest symphony originates from a single note; the most intricate tapestry begins with a solitary thread."},
|
||||
{'role': 'user', 'content': "Teach me about the ocean."},
|
||||
]
|
||||
)
|
||||
print(response['choices'][0]['text'])
|
||||
# Sample Output:
|
||||
# Notre mission est de garantir que l'intelligence artificielle généralisée profite à tous les membres de l'humanité.
|
||||
text = response['choices'][0]['message']['content']
|
||||
print(text)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Compatibility & not so compatibility
|
||||
|
||||
| API endpoint | tested with | notes |
|
||||
| --- | --- | --- |
|
||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways |
|
||||
| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far, variable quality based on the model |
|
||||
| /v1/chat/completions | openai.ChatCompletion.create() | Quality depends a lot on the model |
|
||||
| /v1/edits | openai.Edit.create() | Works the best of all, perfect for instruction following models |
|
||||
| /v1/chat/completions | openai.ChatCompletion.create() | Use it with instruction following models |
|
||||
| /v1/embeddings | openai.Embedding.create() | Using SentenceTransformer embeddings |
|
||||
| /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
|
||||
| /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. |
|
||||
| /v1/moderations | openai.Moderation.create() | does nothing. successfully. |
|
||||
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
|
||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
||||
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
|
||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
|
||||
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
||||
| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
|
||||
| /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
|
||||
| /v1/engines | openai engines.list | Legacy Lists models |
|
||||
| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api |
|
||||
| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api or command line |
|
||||
| /v1/images/edits | openai.Image.create_edit() | not yet supported |
|
||||
| /v1/images/variations | openai.Image.create_variation() | not yet supported |
|
||||
| /v1/audio/\* | openai.Audio.\* | not yet supported |
|
||||
@ -167,7 +187,7 @@ print(response['choices'][0]['text'])
|
||||
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
|
||||
| /v1/search | openai.search, engines.search | not yet supported |
|
||||
|
||||
The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k).
|
||||
Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose.
|
||||
|
||||
Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly.
|
||||
|
||||
@ -175,41 +195,29 @@ Some hacky mappings:
|
||||
|
||||
| OpenAI | text-generation-webui | note |
|
||||
| --- | --- | --- |
|
||||
| model | - | Ignored, the model is not changed |
|
||||
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
|
||||
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
|
||||
| best_of | top_k | default is 1 |
|
||||
| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. |
|
||||
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
|
||||
| n | 1 | variations are not supported yet. |
|
||||
| 1 | num_beams | hardcoded to 1 |
|
||||
| 1.0 | typical_p | hardcoded to 1.0 |
|
||||
| max_tokens | max_new_tokens | For Text Completions max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no input to be generated if the prompt is too large. For ChatCompletions, the older chat messages may be dropped to fit the max_new_tokens requested |
|
||||
| logprobs | - | not supported yet |
|
||||
| logit_bias | - | not supported yet |
|
||||
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
|
||||
| messages.name | - | not supported yet |
|
||||
| user | - | not supported yet |
|
||||
| functions/function_call | - | function calls are not supported yet |
|
||||
|
||||
defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
|
||||
|
||||
### Applications
|
||||
|
||||
Almost everything needs the OPENAI_API_KEY environment variable set, for example:
|
||||
```
|
||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
```
|
||||
Some apps are picky about key format, but 'dummy' or 'sk-dummy' also work in most cases.
|
||||
Most application will work if you also set:
|
||||
```
|
||||
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
||||
```
|
||||
but there are some exceptions.
|
||||
Almost everything needs the ```OPENAI_API_KEY``` and ```OPENAI_API_BASE``` environment variable set, but there are some exceptions.
|
||||
|
||||
| Compatibility | Application/Library | url | notes / setting |
|
||||
| Compatibility | Application/Library | Website | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||
| ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
|
||||
| ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
|
||||
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI |
|
||||
| ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work |
|
||||
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
|
||||
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||
@ -221,11 +229,12 @@ but there are some exceptions.
|
||||
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||
|
||||
## Future plans
|
||||
* better error handling
|
||||
* model changing, esp. something for swapping loras or embedding models
|
||||
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||
|
||||
## Bugs? Feedback? Comments? Pull requests?
|
||||
|
||||
To enable debugging and get copious output you can set the OPENEDAI_DEBUG=1 environment variable.
|
||||
To enable debugging and get copious output you can set the ```OPENEDAI_DEBUG=1``` environment variable.
|
||||
|
||||
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.
|
@ -3,6 +3,7 @@ import yaml
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from math import log, exp
|
||||
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
@ -18,41 +19,50 @@ from extensions.openai.errors import *
|
||||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
def __init__(self, logit_bias={}):
|
||||
self.logit_bias = logit_bias
|
||||
super().__init__()
|
||||
if self.logit_bias:
|
||||
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
||||
values = [ self.logit_bias[str(key)] for key in self.keys ]
|
||||
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
|
||||
debug_msg(f"{self})")
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logit_bias:
|
||||
keys = list([int(key) for key in self.logit_bias.keys()])
|
||||
values = list([int(val) for val in self.logit_bias.values()])
|
||||
logits[0, keys] += torch.tensor(values).cuda()
|
||||
|
||||
debug_msg(logits[0, self.keys], " + ", self.values)
|
||||
logits[0, self.keys] += self.values
|
||||
debug_msg(" --> ", logits[0, self.keys])
|
||||
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
||||
|
||||
class LogprobProcessor(LogitsProcessor):
|
||||
def __init__(self, logprobs=None):
|
||||
self.logprobs = logprobs
|
||||
self.token_alternatives = {}
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logprobs is not None: # 0-5
|
||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||
# XXX hack. should find the selected token and include the prob of that
|
||||
# ... but we just +1 here instead because we don't know it yet.
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1)
|
||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||
top_probs = [ float(x) for x in top_values[0] ]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
|
||||
|
||||
|
||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(model)
|
||||
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
except KeyError:
|
||||
# assume native tokens if we can't find the tokenizer
|
||||
# more problems than it's worth.
|
||||
# try:
|
||||
# encoder = tiktoken.encoding_for_model(model)
|
||||
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
# except KeyError:
|
||||
# # assume native tokens if we can't find the tokenizer
|
||||
return logprobs
|
||||
|
||||
|
||||
@ -73,8 +83,8 @@ def marshal_common_params(body):
|
||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||
@ -87,6 +97,11 @@ def marshal_common_params(body):
|
||||
|
||||
# presence_penalty - ignored
|
||||
# frequency_penalty - ignored
|
||||
|
||||
# pass through unofficial params
|
||||
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
|
||||
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
|
||||
|
||||
# user - ignored
|
||||
|
||||
logits_processor = []
|
||||
@ -98,9 +113,11 @@ def marshal_common_params(body):
|
||||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||
new_logit_bias = {}
|
||||
for logit, bias in logit_bias.items():
|
||||
for x in encode(encoder.decode([int(logit)]))[0]:
|
||||
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
||||
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
||||
continue
|
||||
new_logit_bias[str(int(x))] = bias
|
||||
print(logit_bias, '->', new_logit_bias)
|
||||
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
except KeyError:
|
||||
pass # assume native tokens if we can't find the tokenizer
|
||||
@ -134,11 +151,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
'user': 'user: {message}\n',
|
||||
'assistant': 'assistant: {message}\n',
|
||||
'user': 'User: {message}\n',
|
||||
'assistant': 'Assistant: {message}\n',
|
||||
'system': '{message}',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||
'prompt': 'assistant:',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
|
||||
'prompt': 'Assistant:',
|
||||
}
|
||||
|
||||
if not 'stopping_strings' in req_params:
|
||||
@ -151,10 +168,10 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
system_message_default = instruct.get('context', '') # can be missing
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
@ -173,13 +190,13 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
@ -194,6 +211,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||
|
||||
for m in messages:
|
||||
if 'role' not in m:
|
||||
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||
if 'content' not in m:
|
||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||
|
||||
role = m['role']
|
||||
content = m['content']
|
||||
# name = m.get('name', None)
|
||||
@ -215,12 +237,12 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
|
||||
if token_count >= req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||
raise InvalidRequestError(message=err_msg)
|
||||
raise InvalidRequestError(message=err_msg, param='messages')
|
||||
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
# raise InvalidRequestError(message=err_msg)
|
||||
# raise InvalidRequestError(message=err_msg, params='max_tokens')
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
@ -251,6 +273,10 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
@ -267,7 +293,7 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
@ -323,6 +349,10 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
|
||||
def chat_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
@ -352,7 +382,6 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
@ -375,13 +404,17 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct token_count, strip leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = chat_streaming_chunk('')
|
||||
@ -413,7 +446,7 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
@ -441,7 +474,6 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
@ -475,7 +507,7 @@ def completions(body: dict, is_legacy: bool = False):
|
||||
}
|
||||
}
|
||||
|
||||
if logprob_proc:
|
||||
if logprob_proc and logprob_proc.token_alternatives:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
@ -504,7 +536,7 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
@ -579,9 +611,13 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
@ -46,8 +46,6 @@ def get_default_req_params():
|
||||
return copy.deepcopy(default_req_params)
|
||||
|
||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||
|
||||
|
||||
def default(dic, key, default):
|
||||
val = dic.get(key, default)
|
||||
if type(val) != type(default):
|
||||
|
@ -1,43 +1,54 @@
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import numpy as np
|
||||
from extensions.openai.utils import float_list_to_base64, debug_msg
|
||||
from extensions.openai.errors import *
|
||||
|
||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
|
||||
if embeddings_device.lower() == 'auto':
|
||||
embeddings_device = None
|
||||
|
||||
|
||||
def load_embedding_model(model):
|
||||
def load_embedding_model(model: str) -> SentenceTransformer:
|
||||
global embeddings_device, embeddings_model
|
||||
try:
|
||||
emb_model = SentenceTransformer(model)
|
||||
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||
embeddings_model = 'loading...' # flag
|
||||
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
|
||||
emb_model = SentenceTransformer(model, device=embeddings_device)
|
||||
# ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
|
||||
print(f"\nLoaded embedding model: {model} on {emb_model.device} [always seems to say 'cpu', even if 'cuda'], max sequence length: {emb_model.max_seq_length}")
|
||||
except Exception as e:
|
||||
print(f"\nError: Failed to load embedding model: {model}")
|
||||
embeddings_model = None
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
return emb_model
|
||||
|
||||
|
||||
def get_embeddings_model():
|
||||
def get_embeddings_model() -> SentenceTransformer:
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
return embeddings_model
|
||||
|
||||
|
||||
def get_embeddings_model_name():
|
||||
def get_embeddings_model_name() -> str:
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str):
|
||||
def get_embeddings(input: list) -> np.ndarray:
|
||||
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
|
||||
|
||||
embeddings = get_embeddings_model().encode(input).tolist()
|
||||
def embeddings(input: list, encoding_format: str) -> dict:
|
||||
|
||||
embeddings = get_embeddings(input)
|
||||
|
||||
if encoding_format == "base64":
|
||||
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||
else:
|
||||
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||
data = [{"object": "embedding", "embedding": emb.tolist(), "index": n} for n, emb in enumerate(embeddings)]
|
||||
|
||||
response = {
|
||||
"object": "list",
|
||||
|
@ -13,8 +13,8 @@ class OpenAIError(Exception):
|
||||
|
||||
|
||||
class InvalidRequestError(OpenAIError):
|
||||
def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
def __init__(self, message, param, code=400, internal_message=''):
|
||||
super().__init__(message, code, internal_message)
|
||||
self.param = param
|
||||
|
||||
def __repr__(self):
|
||||
@ -27,5 +27,5 @@ class InvalidRequestError(OpenAIError):
|
||||
|
||||
|
||||
class ServiceUnavailableError(OpenAIError):
|
||||
def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
|
||||
super().__init__(message, code, internal_message)
|
||||
|
@ -9,12 +9,16 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
||||
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
||||
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
||||
# Will probably work best with the stock SD models.
|
||||
# SD configuration is beyond the scope of this API.
|
||||
# it's too general an API to try and shape the result with specific tags like negative prompts
|
||||
# or "masterpiece", etc. SD configuration is beyond the scope of this API.
|
||||
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||
# require changing the form data handling to accept multipart form data, also to properly support
|
||||
# url return types will require file management and a web serving files... Perhaps later!
|
||||
base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
|
||||
sd_defaults = {
|
||||
'sampler_name': 'DPM++ 2M Karras', # vast improvement
|
||||
'steps': 30,
|
||||
}
|
||||
|
||||
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
|
||||
|
||||
@ -24,8 +28,21 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
'width': width,
|
||||
'height': height,
|
||||
'batch_size': n,
|
||||
'restore_faces': True, # slightly less horrible
|
||||
}
|
||||
payload.update(sd_defaults)
|
||||
|
||||
scale = min(width, height) / base_model_size
|
||||
if scale >= 1.2:
|
||||
# for better performance with the default size (1024), and larger res.
|
||||
scaler = {
|
||||
'width': width // scale,
|
||||
'height': height // scale,
|
||||
'hr_scale': scale,
|
||||
'enable_hr': True,
|
||||
'hr_upscaler': 'Latent',
|
||||
'denoising_strength': 0.68,
|
||||
}
|
||||
payload.update(scaler)
|
||||
|
||||
resp = {
|
||||
'created': int(time.time()),
|
||||
@ -38,7 +55,8 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
response = requests.post(url=sd_url, json=payload)
|
||||
r = response.json()
|
||||
if response.status_code != 200 or 'images' not in r:
|
||||
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code=response.status_code)
|
||||
print(r)
|
||||
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors',None))
|
||||
# r['parameters']...
|
||||
for b64_json in r['images']:
|
||||
if response_format == 'b64_json':
|
||||
|
@ -1,7 +1,7 @@
|
||||
import time
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
from extensions.openai.embeddings import get_embeddings_model
|
||||
from extensions.openai.embeddings import get_embeddings
|
||||
|
||||
|
||||
moderations_disabled = False # return 0/false
|
||||
@ -11,21 +11,21 @@ categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hat
|
||||
flag_threshold = 0.5
|
||||
|
||||
|
||||
def get_category_embeddings():
|
||||
def get_category_embeddings() -> dict:
|
||||
global category_embeddings, categories
|
||||
if category_embeddings is None:
|
||||
embeddings = get_embeddings_model().encode(categories).tolist()
|
||||
embeddings = get_embeddings(categories).tolist()
|
||||
category_embeddings = dict(zip(categories, embeddings))
|
||||
|
||||
return category_embeddings
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
return np.dot(a, b) / (norm(a) * norm(b))
|
||||
|
||||
|
||||
# seems most openai like with all-mpnet-base-v2
|
||||
def mod_score(a, b):
|
||||
def mod_score(a: np.ndarray, b: np.ndarray) -> float:
|
||||
return 2.0 * np.dot(a, b)
|
||||
|
||||
|
||||
@ -37,8 +37,7 @@ def moderations(input):
|
||||
"results": [],
|
||||
}
|
||||
|
||||
embeddings_model = get_embeddings_model()
|
||||
if not embeddings_model or moderations_disabled:
|
||||
if moderations_disabled:
|
||||
results['results'] = [{
|
||||
'categories': dict([(C, False) for C in categories]),
|
||||
'category_scores': dict([(C, 0.0) for C in categories]),
|
||||
@ -53,7 +52,7 @@ def moderations(input):
|
||||
input = [input]
|
||||
|
||||
for in_str in input:
|
||||
for ine in embeddings_model.encode([in_str]).tolist():
|
||||
for ine in get_embeddings([in_str]):
|
||||
category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
|
||||
category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
|
||||
flagged = any(category_flags.values())
|
||||
|
@ -55,11 +55,13 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
def send_sse(self, chunk: dict):
|
||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||
debug_msg(response)
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
def end_sse(self):
|
||||
self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))
|
||||
response = 'data: [DONE]\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||
self.send_response(code)
|
||||
@ -84,6 +86,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
}
|
||||
}
|
||||
if internal_message:
|
||||
print(error_type, message)
|
||||
print(internal_message)
|
||||
# error_resp['internal_message'] = internal_message
|
||||
|
||||
@ -93,12 +96,10 @@ class Handler(BaseHTTPRequestHandler):
|
||||
def wrapper(self):
|
||||
try:
|
||||
func(self)
|
||||
except ServiceUnavailableError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||
except InvalidRequestError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
|
||||
except OpenAIError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
|
||||
except Exception as e:
|
||||
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||
|
||||
@ -143,8 +144,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if '/completions' in self.path or '/generate' in self.path:
|
||||
|
||||
if not shared.model:
|
||||
self.openai_error("No model loaded.")
|
||||
return
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
is_legacy = '/generate' in self.path
|
||||
is_streaming = body.get('stream', False)
|
||||
@ -176,8 +176,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
# deprecated
|
||||
|
||||
if not shared.model:
|
||||
self.openai_error("No model loaded.")
|
||||
return
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
req_params = get_default_req_params()
|
||||
|
||||
@ -190,7 +189,10 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
self.return_json(response)
|
||||
|
||||
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||
elif '/images/generations' in self.path:
|
||||
if not 'SD_WEBUI_URL' in os.environ:
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
prompt = body['prompt']
|
||||
size = default(body, 'size', '1024x1024')
|
||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||
@ -256,11 +258,11 @@ def run_server():
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
else:
|
||||
print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
|
||||
server.serve_forever()
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from extensions.openai.utils import float_list_to_base64
|
||||
from modules.text_generation import encode, decode
|
||||
|
||||
import numpy as np
|
||||
|
||||
def token_count(prompt):
|
||||
tokens = encode(prompt)[0]
|
||||
@ -12,14 +12,13 @@ def token_count(prompt):
|
||||
}
|
||||
|
||||
|
||||
def token_encode(input, encoding_format=''):
|
||||
def token_encode(input, encoding_format):
|
||||
# if isinstance(input, list):
|
||||
tokens = encode(input)[0]
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'encoding_format': encoding_format,
|
||||
'tokens': float_list_to_base64(tokens) if encoding_format == "base64" else tokens,
|
||||
'tokens': tokens,
|
||||
'length': len(tokens),
|
||||
}]
|
||||
}
|
||||
|
@ -3,9 +3,9 @@ import base64
|
||||
import numpy as np
|
||||
|
||||
|
||||
def float_list_to_base64(float_list):
|
||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
float_array = np.array(float_list, dtype="float32")
|
||||
#float_array = np.array(float_list, dtype="float32")
|
||||
|
||||
# Get raw bytes
|
||||
bytes_array = float_array.tobytes()
|
||||
|
Loading…
Reference in New Issue
Block a user