Make OpenAI API the default API (#4430)

This commit is contained in:
oobabooga 2023-11-06 02:38:29 -03:00 committed by GitHub
parent 84d957ba62
commit ec17a5d2b7
22 changed files with 769 additions and 1432 deletions


@@ -22,7 +22,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character)
* Very efficient text streaming
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples))
* OpenAI-compatible API server

## Documentation

@@ -412,8 +412,8 @@ Optionally, you can use the following command-line flags:
| `--api` | Enable the API extension. |
| `--public-api` | Create a public URL for the API using Cloudflare. |
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. |
| `--api-port API_PORT` | The listening port for the API. |
| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. |
| `--api-key API_KEY` | API authentication key. |

#### Multimodal


@@ -1,114 +0,0 @@
import asyncio
import html
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")

# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/chat-stream'

# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'


async def run(user_input, history):
    # Note: the selected defaults change from time to time.
    request = {
        'user_input': user_input,
        'max_new_tokens': 250,
        'auto_max_new_tokens': False,
        'max_tokens_second': 0,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',  # Will get autodetected if unset
        'your_name': 'You',
        # 'name1': 'name of user',  # Optional
        # 'name2': 'name of character',  # Optional
        # 'context': 'character context',  # Optional
        # 'greeting': 'greeting',  # Optional
        # 'name1_instruct': 'You',  # Optional
        # 'name2_instruct': 'Assistant',  # Optional
        # 'context_instruct': 'context_instruct',  # Optional
        # 'turn_template': 'turn_template',  # Optional
        'regenerate': False,
        '_continue': False,
        'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',

        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
        'preset': 'None',
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'epsilon_cutoff': 0,  # In units of 1e-4
        'eta_cutoff': 0,  # In units of 1e-4
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1.18,
        'presence_penalty': 0,
        'frequency_penalty': 0,
        'repetition_penalty_range': 0,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'grammar_string': '',
        'guidance_scale': 1,
        'negative_prompt': '',

        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    async with websockets.connect(URI, ping_interval=None) as websocket:
        await websocket.send(json.dumps(request))

        while True:
            incoming_data = await websocket.recv()
            incoming_data = json.loads(incoming_data)

            match incoming_data['event']:
                case 'text_stream':
                    yield incoming_data['history']
                case 'stream_end':
                    return


async def print_response_stream(user_input, history):
    cur_len = 0
    async for new_history in run(user_input, history):
        cur_message = new_history['visible'][-1][1][cur_len:]
        cur_len += len(cur_message)
        print(html.unescape(cur_message), end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    asyncio.run(print_response_stream(user_input, history))


@@ -1,94 +0,0 @@
import html
import json

import requests

# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/chat'

# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'


def run(user_input, history):
    request = {
        'user_input': user_input,
        'max_new_tokens': 250,
        'auto_max_new_tokens': False,
        'max_tokens_second': 0,
        'history': history,
        'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
        'character': 'Example',
        'instruction_template': 'Vicuna-v1.1',  # Will get autodetected if unset
        'your_name': 'You',
        # 'name1': 'name of user',  # Optional
        # 'name2': 'name of character',  # Optional
        # 'context': 'character context',  # Optional
        # 'greeting': 'greeting',  # Optional
        # 'name1_instruct': 'You',  # Optional
        # 'name2_instruct': 'Assistant',  # Optional
        # 'context_instruct': 'context_instruct',  # Optional
        # 'turn_template': 'turn_template',  # Optional
        'regenerate': False,
        '_continue': False,
        'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',

        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
        'preset': 'None',
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'epsilon_cutoff': 0,  # In units of 1e-4
        'eta_cutoff': 0,  # In units of 1e-4
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1.18,
        'presence_penalty': 0,
        'frequency_penalty': 0,
        'repetition_penalty_range': 0,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'grammar_string': '',
        'guidance_scale': 1,
        'negative_prompt': '',

        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    response = requests.post(URI, json=request)

    if response.status_code == 200:
        result = response.json()['results'][0]['history']
        print(json.dumps(result, indent=4))
        print()
        print(html.unescape(result['visible'][-1][1]))


if __name__ == '__main__':
    user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."

    # Basic example
    history = {'internal': [], 'visible': []}

    # "Continue" example. Make sure to set '_continue' to True above
    # arr = [user_input, 'Surely, here is']
    # history = {'internal': [arr], 'visible': [arr]}

    run(user_input, history)


@@ -1,88 +0,0 @@
import asyncio
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")

# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/stream'

# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'


async def run(context):
    # Note: the selected defaults change from time to time.
    request = {
        'prompt': context,
        'max_new_tokens': 250,
        'auto_max_new_tokens': False,
        'max_tokens_second': 0,

        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
        'preset': 'None',
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'epsilon_cutoff': 0,  # In units of 1e-4
        'eta_cutoff': 0,  # In units of 1e-4
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1.18,
        'presence_penalty': 0,
        'frequency_penalty': 0,
        'repetition_penalty_range': 0,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'grammar_string': '',
        'guidance_scale': 1,
        'negative_prompt': '',

        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    async with websockets.connect(URI, ping_interval=None) as websocket:
        await websocket.send(json.dumps(request))

        yield context  # Remove this if you just want to see the reply

        while True:
            incoming_data = await websocket.recv()
            incoming_data = json.loads(incoming_data)

            match incoming_data['event']:
                case 'text_stream':
                    yield incoming_data['text']
                case 'stream_end':
                    return


async def print_response_stream(prompt):
    async for response in run(prompt):
        print(response, end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.


if __name__ == '__main__':
    prompt = "In order to make homemade bread, follow these steps:\n1)"
    asyncio.run(print_response_stream(prompt))


@@ -1,65 +0,0 @@
import requests

# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/generate'

# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'


def run(prompt):
    request = {
        'prompt': prompt,
        'max_new_tokens': 250,
        'auto_max_new_tokens': False,
        'max_tokens_second': 0,

        # Generation params. If 'preset' is set to different than 'None', the values
        # in presets/preset-name.yaml are used instead of the individual numbers.
        'preset': 'None',
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.1,
        'typical_p': 1,
        'epsilon_cutoff': 0,  # In units of 1e-4
        'eta_cutoff': 0,  # In units of 1e-4
        'tfs': 1,
        'top_a': 0,
        'repetition_penalty': 1.18,
        'presence_penalty': 0,
        'frequency_penalty': 0,
        'repetition_penalty_range': 0,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
        'penalty_alpha': 0,
        'length_penalty': 1,
        'early_stopping': False,
        'mirostat_mode': 0,
        'mirostat_tau': 5,
        'mirostat_eta': 0.1,
        'grammar_string': '',
        'guidance_scale': 1,
        'negative_prompt': '',

        'seed': -1,
        'add_bos_token': True,
        'truncation_length': 2048,
        'ban_eos_token': False,
        'custom_token_bans': '',
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    response = requests.post(URI, json=request)

    if response.status_code == 200:
        result = response.json()['results'][0]['text']
        print(prompt + result)


if __name__ == '__main__':
    prompt = "In order to make homemade bread, follow these steps:\n1)"
    run(prompt)


@@ -1,124 +1,64 @@
# An OpenedAI API (openai like)
## OpenAI compatible API

This extension creates an API that works kind of like openai (ie. api.openai.com).
This project includes an API compatible with multiple OpenAI endpoints, including Chat and Completions.

## Setup & installation
If you did not use the one-click installers, you may need to install the requirements first:

Install the requirements:

```
pip3 install -r requirements.txt
pip install -r extensions/openai/requirements.txt
```

It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this.

### Starting the API

Make sure you enable it in server launch parameters, it should include:
Add `--extensions openai` to your command-line flags.

* To create a public Cloudflare URL, add the `--public-api` flag.
* To listen on your local network, add the `--listen` flag.
* To change the port, which is 5000 by default, use `--port 1234` (change 1234 to your desired port number).
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
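With the API running as described above, you can sanity-check it from any HTTP client. Below is a minimal sketch using Python's `requests`; the port (5000) and the `/v1/chat/completions` path follow the defaults described in this README, so adjust them if your setup differs.

```python
import requests

# Minimal sketch: one non-streaming chat request to the local OpenAI-compatible API.
# The URL assumes the default port; change it if you picked a different port or enabled SSL.
url = "http://127.0.0.1:5000/v1/chat/completions"

payload = {
    "model": "anything",  # ignored by the extension; the currently loaded model is used
    "messages": [{"role": "user", "content": "Hello! Who are you?"}],
    "max_tokens": 200,
}

response = requests.post(url, json=payload, timeout=120)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```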
#### Environment variables
The following environment variables can be used (they take precedence over everything else):
| Variable Name | Description | Example Value |
|------------------------|------------------------------------|----------------------------|
| `OPENEDAI_PORT` | Port number | 5000 |
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
#### Persistent settings with `settings.yaml`
You can also set default values by adding these lines to your `settings.yaml` file:
```
--extensions openai
```
You can also use the `--listen` argument to make the server available on the network, and/or the `--share` argument to enable a public Cloudflare endpoint.
To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
For example:
```
SD_WEBUI_URL=http://127.0.0.1:7861
```
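As a sketch of how this can be used from the client side, the request below assumes the extension exposes the standard `/v1/images/generations` route backed by the `SD_WEBUI_URL` you configured, and that the API is reachable at the address shown; both are assumptions to adapt to your setup.

```python
import openai

# Sketch only: dummy key, and a base URL that must match wherever your API is listening.
openai.api_key = "sk-111111111111111111111111111111111111111111111111"
openai.api_base = "http://127.0.0.1:5001/v1"

result = openai.Image.create(
    prompt="a watercolor painting of a lighthouse at dawn",
    n=1,
    size="512x512",
    response_format="b64_json",  # part of the OpenAI spec; whether the extension honors it is an assumption
)
print(list(result["data"][0].keys()))  # typically 'b64_json' or 'url', depending on the server
```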
## Quick start
1. Install the requirements.txt (pip)
2. Enable the `openai` module (--extensions openai), restart the server.
3. Configure the openai client
Most openai applications can be configured to connect to the API if you set the following environment variables:
```shell
# Sample .env file:
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://0.0.0.0:5001/v1
```
If needed, replace 0.0.0.0 with the IP/port of your server.
### Settings
To adjust your default settings, you can add the following to your `settings.yaml` file.
```
openai-port: 5002
openai-embedding_device: cuda
openai-embedding_model: all-mpnet-base-v2
openai-sd_webui_url: http://127.0.0.1:7861
openai-debug: 1
```

If you've configured the environment variables, please note that settings from `settings.yaml` won't take effect. For instance, if you set `openai-port: 5002` in `settings.yaml` but `OPENEDAI_PORT=5001` in the environment variables, the extension will use `5001` as the port number.

### Examples
When using `cache_embedding_model.py` to preload the embedding model during Docker image building, consider the following:
- If you wish to use the default settings, leave the environment variables unset.
- If you intend to change the default embedding model, ensure that you configure the environment variable `OPENEDAI_EMBEDDING_MODEL` to the desired model. Avoid setting `openai-embedding_model` in `settings.yaml` because those settings only take effect after the server starts.
### Models
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants (e.g. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models, please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start.
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date.
A matching instruction template file in the characters/instruction-following/ folder will be loaded and applied to format messages correctly for the model - this is critical for good results.
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
```
.*wizard.*vicuna:
mode: 'instruct'
instruction_template: 'Vicuna-v1.1'
```
This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this:
```
user: "USER:"
bot: "ASSISTANT:"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
```
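To make the placeholders concrete, here is a small illustrative sketch (not the webui's actual prompt-building code) of how that `turn_template` expands for a single exchange:

```python
# Illustrative only: expand the Vicuna-v1.1 turn_template shown above for one user/bot turn.
template = "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"

turn = (template
        .replace("<|user|>", "USER:")
        .replace("<|bot|>", "ASSISTANT:")
        .replace("<|user-message|>", "What is the capital of France?")
        .replace("<|bot-message|>", "The capital of France is Paris."))

print(turn)
# USER: What is the capital of France?
# ASSISTANT: The capital of France is Paris.</s>
```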
For most common models this is already set up, but if you are using a new or uncommon model you may need to add a matching entry to models/config.yaml and possibly create your own instruction-following template for best results.
If you see this in your logs, it probably means that the correct format could not be loaded:
```
Warning: Loaded default instruction-following template for model.
```
### Embeddings (alpha)
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
| model name | dimensions | input max tokens | speed | size | Avg. performance |
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
In short, the all-MiniLM-L6-v2 model is 5x faster, uses 5x less RAM and 2x less storage, and still offers good quality. Stats are from https://www.sbert.net/docs/pretrained_models.html. To change the model from the default, set the environment variable `OPENEDAI_EMBEDDING_MODEL`, e.g. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
### Client Application Setup

Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.

You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:

```shell
OPENAI_API_HOST=http://127.0.0.1:5000
```

or

```shell
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://127.0.0.1:5001/v1
```

With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:

@@ -128,7 +68,7 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
OPENAI_API_BASE=http://0.0.0.0:5001/v1
```

If needed, replace 0.0.0.0 with the IP/port of your server.
If needed, replace 127.0.0.1 with the IP/port of your server.
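As a concrete sketch, the legacy (0.x) `openai` Python package can also be pointed at the server directly in code rather than through environment variables; the key is a dummy value and the base URL below is an example that should match wherever your API is listening:

```python
import openai

# Any non-empty key works; the base URL must point at your local server's /v1 path.
openai.api_key = "sk-111111111111111111111111111111111111111111111111"
openai.api_base = "http://127.0.0.1:5001/v1"

response = openai.ChatCompletion.create(
    model="ignored-locally",  # the loaded model is used regardless of this value
    messages=[{"role": "user", "content": "Write a haiku about open source."}],
)
print(response["choices"][0]["message"]["content"])
```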
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported: If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
@@ -157,8 +97,22 @@ const api = new ChatGPTAPI({
  apiBaseUrl: process.env.OPENAI_API_BASE
});
```

### Embeddings (alpha)

## API Documentation & Examples

Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
| model name | dimensions | input max tokens | speed | size | Avg. performance |
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
| text-davinci-002 | 768 | 2046 | - | - | - |
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
In short, the all-MiniLM-L6-v2 model is 5x faster, uses 5x less RAM and 2x less storage, and still offers good quality. Stats are from https://www.sbert.net/docs/pretrained_models.html. To change the model from the default, set the environment variable `OPENEDAI_EMBEDDING_MODEL`, e.g. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
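For reference, a minimal embeddings call through the legacy `openai` client might look like the sketch below; the `model` value is illustrative, since the server uses whichever embedding model it was configured with (see `OPENEDAI_EMBEDDING_MODEL` above):

```python
import openai

openai.api_key = "sk-111111111111111111111111111111111111111111111111"  # dummy key
openai.api_base = "http://127.0.0.1:5001/v1"  # adjust to your server

result = openai.Embedding.create(
    model="all-mpnet-base-v2",  # illustrative; the server-side default is described above
    input=["The quick brown fox jumps over the lazy dog."],
)

vector = result["data"][0]["embedding"]
print(len(vector))  # 768 dimensions for all-mpnet-base-v2, per the table above
```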
### API Documentation & Examples
The OpenAI API is well documented; you can view the documentation here: https://platform.openai.com/docs/api-reference

@@ -185,7 +139,7 @@ text = response['choices'][0]['message']['content']
print(text)
```

## Compatibility & not so compatibility
### Compatibility & not so compatibility

| API endpoint | tested with | notes |
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
@@ -195,7 +149,7 @@ print(text)
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
@@ -209,28 +163,8 @@ print(text)
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
| /v1/search | openai.search, engines.search | not yet supported |

Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc.) you may need to adjust the max_tokens to fit into the context of the model you choose.

Streaming, temperature, top_p, max_tokens, and stop should all work as expected, but not all parameters are mapped correctly.

#### Applications
Some hacky mappings:
| OpenAI | text-generation-webui | note |
| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| model | - | Ignored, the model is not changed |
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
| n | 1 | variations are not supported yet. |
| 1 | num_beams | hardcoded to 1 |
| 1.0 | typical_p | hardcoded to 1.0 |
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
| messages.name | - | not supported yet |
| suffix | - | not supported yet |
| user | - | not supported yet |
| functions/function_call | - | function calls are not supported yet |
### Applications
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables set, but there are some exceptions.

@@ -249,15 +183,3 @@ Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment v
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
## Future plans
- better error handling
- model changing, esp. something for swapping loras or embedding models
- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
## Bugs? Feedback? Comments? Pull requests?
To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable.
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.


@@ -3,9 +3,11 @@ import time

import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
from modules.logging_colors import logger


def setup():
    logger.warning("The current API is deprecated and will be replaced with the OpenAI compatible API on November xxth. To test the new API, use \"--extensions openai\" instead of \"--api\".")
    blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
    if shared.args.public_api:
        time.sleep(5)


@ -1,18 +1,23 @@
import copy
import time import time
from collections import deque
import tiktoken import tiktoken
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import yaml
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import InvalidRequestError from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, end_line from extensions.openai.utils import debug_msg
from modules import shared from modules import shared
from modules.chat import (
generate_chat_prompt,
generate_chat_reply,
load_character_memoized
)
from modules.presets import load_preset_memoized
from modules.text_generation import decode, encode, generate_reply from modules.text_generation import decode, encode, generate_reply
from transformers import LogitsProcessor, LogitsProcessorList from transformers import LogitsProcessor, LogitsProcessorList
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
class LogitsBiasProcessor(LogitsProcessor): class LogitsBiasProcessor(LogitsProcessor):
def __init__(self, logit_bias={}): def __init__(self, logit_bias={}):
self.logit_bias = logit_bias self.logit_bias = logit_bias
@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
logits[0, self.keys] += self.values logits[0, self.keys] += self.values
debug_msg(" --> ", logits[0, self.keys]) debug_msg(" --> ", logits[0, self.keys])
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0]))) debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
return logits return logits
def __repr__(self): def __repr__(self):
@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor):
top_probs = [float(x) for x in top_values[0]] top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs)) self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(repr(self)) debug_msg(repr(self))
return logits return logits
def __repr__(self): def __repr__(self):
@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs):
return logprobs return logprobs
def marshal_common_params(body): def process_parameters(body, is_legacy=False):
# Request Parameters generate_params = body
# Try to use openai defaults or map them to something with the same intent max_tokens_str = 'length' if is_legacy else 'max_tokens'
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
if generate_params['truncation_length'] == 0:
if shared.args.loader and shared.args.loader.lower().startswith('exllama'):
generate_params['truncation_length'] = shared.args.max_seq_len
elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
generate_params['truncation_length'] = shared.args.n_ctx
else:
generate_params['truncation_length'] = shared.settings['truncation_length']
req_params = get_default_req_params() if body['preset'] is not None:
preset = load_preset_memoized(body['preset'])
# Common request parameters generate_params.update(preset)
req_params['truncation_length'] = shared.settings['truncation_length']
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
# OpenAI API Parameters
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
req_params['requested_model'] = body.get('model', shared.model_name)
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
n = default(body, 'n', 1)
if n != 1:
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
generate_params['custom_stopping_strings'] = []
if 'stop' in body: # str or array, max len 4 (ignored) if 'stop' in body: # str or array, max len 4 (ignored)
if isinstance(body['stop'], str): if isinstance(body['stop'], str):
req_params['stopping_strings'] = [body['stop']] # non-standard parameter generate_params['custom_stopping_strings'] = [body['stop']]
elif isinstance(body['stop'], list): elif isinstance(body['stop'], list):
req_params['stopping_strings'] = body['stop'] generate_params['custom_stopping_strings'] = body['stop']
# presence_penalty - ignored
# frequency_penalty - ignored
# pass through unofficial params
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
# user - ignored
logits_processor = [] logits_processor = []
logit_bias = body.get('logit_bias', None) logit_bias = body.get('logit_bias', None)
@ -110,12 +102,13 @@ def marshal_common_params(body):
# XXX convert tokens from tiktoken based on requested model # XXX convert tokens from tiktoken based on requested model
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100} # Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
try: try:
encoder = tiktoken.encoding_for_model(req_params['requested_model']) encoder = tiktoken.encoding_for_model(generate_params['model'])
new_logit_bias = {} new_logit_bias = {}
for logit, bias in logit_bias.items(): for logit, bias in logit_bias.items():
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]: for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
continue continue
new_logit_bias[str(int(x))] = bias new_logit_bias[str(int(x))] = bias
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias) debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
logit_bias = new_logit_bias logit_bias = new_logit_bias
@ -126,238 +119,129 @@ def marshal_common_params(body):
logprobs = None # coming to chat eventually logprobs = None # coming to chat eventually
if 'logprobs' in body: if 'logprobs' in body:
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5. logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
req_params['logprob_proc'] = LogprobProcessor(logprobs) generate_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([req_params['logprob_proc']]) logits_processor.extend([generate_params['logprob_proc']])
else: else:
logprobs = None logprobs = None
if logits_processor: # requires logits_processor support if logits_processor: # requires logits_processor support
req_params['logits_processor'] = LogitsProcessorList(logits_processor) generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
return req_params return generate_params
def messages_to_prompt(body: dict, req_params: dict, max_tokens): def convert_history(history):
# functions '''
if body.get('functions', []): # chat only Chat histories in this program are in the format [message, reply].
This function converts OpenAI histories to that format.
'''
chat_dialogue = []
current_message = ""
current_reply = ""
user_input = ""
for entry in history:
content = entry["content"]
role = entry["role"]
if role == "user":
user_input = content
if current_message:
chat_dialogue.append([current_message, ''])
current_message = ""
current_message = content
elif role == "assistant":
current_reply = content
if current_message:
chat_dialogue.append([current_message, current_reply])
current_message = ""
current_reply = ""
else:
chat_dialogue.append(['', current_reply])
# if current_message:
# chat_dialogue.append([current_message, ''])
return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
if body.get('functions', []):
raise InvalidRequestError(message="functions is not supported.", param='functions') raise InvalidRequestError(message="functions is not supported.", param='functions')
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
if body.get('function_call', ''):
raise InvalidRequestError(message="function_call is not supported.", param='function_call') raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if 'messages' not in body: if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages') raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages'] messages = body['messages']
role_formats = {
'user': 'User: {message}\n',
'assistant': 'Assistant: {message}\n',
'system': '{message}',
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
'prompt': 'Assistant:',
}
if 'stopping_strings' not in req_params:
req_params['stopping_strings'] = []
# Instruct models can be much better
if shared.settings['instruction_template']:
try:
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct.get('context', '') # can be missing
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
'system': system_message_template,
'context': system_message_default,
'prompt': bot_prompt,
}
if 'Alpaca' in shared.settings['instruction_template']:
req_params['stopping_strings'].extend(['\n###'])
elif instruct['user']: # WizardLM and some others have no user prompt.
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except Exception as e:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template for model.")
else:
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
print("Warning: Loaded default instruction-following template for model.")
system_msgs = []
chat_msgs = []
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
context_msg = end_line(context_msg)
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
if 'prompt' in body:
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
for m in messages: for m in messages:
if 'role' not in m: if 'role' not in m:
raise InvalidRequestError(message="messages: missing role", param='messages') raise InvalidRequestError(message="messages: missing role", param='messages')
elif m['role'] == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
if 'content' not in m: if 'content' not in m:
raise InvalidRequestError(message="messages: missing content", param='messages') raise InvalidRequestError(message="messages: missing content", param='messages')
role = m['role']
content = m['content']
# name = m.get('name', None)
# function_call = m.get('function_call', None) # user name or function name with output in content
msg = role_formats[role].format(message=content)
if role == 'system':
system_msgs.extend([msg])
elif role == 'function':
raise InvalidRequestError(message="role: function is not supported.", param='messages')
else:
chat_msgs.extend([msg])
system_msg = '\n'.join(system_msgs)
system_msg = end_line(system_msg)
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
token_count = len(encode(prompt)[0])
if token_count >= req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
raise InvalidRequestError(message=err_msg, param='messages')
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
print(f"Warning: ${err_msg}")
# raise InvalidRequestError(message=err_msg, params='max_tokens')
return prompt, token_count
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
# Chat Completions # Chat Completions
object_type = 'chat.completions' object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# common params # generation parameters
req_params = marshal_common_params(body) generate_params = process_parameters(body, is_legacy=is_legacy)
req_params['stream'] = False continue_ = body['continue_']
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible # Instruction template
max_tokens = 0 instruction_template = body['instruction_template'] or shared.settings['instruction_template']
max_tokens_str = 'length' if is_legacy else 'max_tokens' name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
if max_tokens_str in body: name1_instruct = body['name1_instruct'] or name1_instruct
max_tokens = default(body, max_tokens_str, req_params['truncation_length']) name2_instruct = body['name2_instruct'] or name2_instruct
req_params['max_new_tokens'] = max_tokens context_instruct = body['context_instruct'] or context_instruct
else: turn_template = body['turn_template'] or turn_template
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages # Chat character
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings'] character = body['character'] or shared.settings['character']
name1 = body['name1'] or shared.settings['name1']
name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False)
name2 = body['name2'] or name2
context = body['context'] or context
greeting = body['greeting'] or greeting
# set real max, avoid deeper errors # History
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: user_input, history = convert_history(messages)
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
stopping_strings = req_params.pop('stopping_strings', []) generate_params.update({
'mode': body['mode'],
'name1': name1,
'name2': name2,
'context': context,
'greeting': greeting,
'name1_instruct': name1_instruct,
'name2_instruct': name2_instruct,
'context_instruct': context_instruct,
'turn_template': turn_template,
'chat-instruct_command': body['chat_instruct_command'],
'history': history,
'stream': stream
})
# generate reply ####################################### max_tokens = generate_params['max_new_tokens']
debug_msg({'prompt': prompt, 'req_params': req_params}) if max_tokens in [None, 0]:
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) generate_params['max_new_tokens'] = 200
generate_params['auto_max_new_tokens'] = True
answer = '' requested_model = generate_params.pop('model')
for a in generator: logprob_proc = generate_params.pop('logprob_proc', None)
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
stop_reason = "length"
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
return resp
# generator
def stream_chat_completions(body: dict, is_legacy: bool = False):
# Chat Completions
stream_object_type = 'chat.completions.chunk'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
req_params = marshal_common_params(body)
req_params['stream'] = True
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
max_tokens_str = 'length' if is_legacy else 'max_tokens'
if max_tokens_str in body:
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
req_params['max_new_tokens'] = max_tokens
else:
req_params['max_new_tokens'] = req_params['truncation_length']
# format the prompt from messages
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
# set real max, avoid deeper errors
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
def chat_streaming_chunk(content): def chat_streaming_chunk(content):
# begin streaming # begin streaming
chunk = { chunk = {
"id": cmpl_id, "id": cmpl_id,
"object": stream_object_type, "object": object_type,
"created": created_time, "created": created_time,
"model": shared.model_name, "model": shared.model_name,
resp_list: [{ resp_list: [{
@ -376,262 +260,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
# chunk[resp_list][0]["logprobs"] = None # chunk[resp_list][0]["logprobs"] = None
return chunk return chunk
yield chat_streaming_chunk('') if stream:
yield chat_streaming_chunk('')
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params}) prompt = generate_chat_prompt(user_input, generate_params)
token_count = len(encode(prompt)[0])
debug_msg({'prompt': prompt, 'generate_params': generate_params})
stopping_strings = req_params.pop('stopping_strings', []) generator = generate_chat_reply(
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = '' answer = ''
seen_content = '' seen_content = ''
completion_token_count = 0 completion_token_count = 0
for a in generator: for a in generator:
answer = a answer = a['internal'][-1][1]
if stream:
len_seen = len(seen_content)
new_content = answer[len_seen:]
len_seen = len(seen_content) if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
new_content = answer[len_seen:] continue
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. seen_content = answer
continue
seen_content = answer # strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
# strip extra leading space off new generated content chunk = chat_streaming_chunk(new_content)
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = chat_streaming_chunk(new_content) yield chunk
yield chunk
# to get the correct token_count, strip leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']: if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
stop_reason = "length" stop_reason = "length"
chunk = chat_streaming_chunk('') if stream:
chunk[resp_list][0]['finish_reason'] = stop_reason chunk = chat_streaming_chunk('')
chunk['usage'] = { chunk[resp_list][0]['finish_reason'] = stop_reason
"prompt_tokens": token_count, chunk['usage'] = {
"completion_tokens": completion_token_count, "prompt_tokens": token_count,
"total_tokens": token_count + completion_token_count "completion_tokens": completion_token_count,
} "total_tokens": token_count + completion_token_count
}
yield chunk yield chunk
else:
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
yield resp
def completions(body: dict, is_legacy: bool = False): def completions_common(body: dict, is_legacy: bool = False, stream=False):
# Legacy object_type = 'text_completion.chunk' if stream else 'text_completion'
# Text Completions
object_type = 'text_completion'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt' prompt_str = 'context' if is_legacy else 'prompt'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
if prompt_str not in body: if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str) raise InvalidRequestError("Missing required input", param=prompt_str)
prompt_arg = body[prompt_str]
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
# common params # common params
req_params = marshal_common_params(body) generate_params = process_parameters(body, is_legacy=is_legacy)
req_params['stream'] = False max_tokens = generate_params['max_new_tokens']
max_tokens_str = 'length' if is_legacy else 'max_tokens' generate_params['stream'] = stream
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) requested_model = generate_params.pop('model')
req_params['max_new_tokens'] = max_tokens logprob_proc = generate_params.pop('logprob_proc', None)
requested_model = req_params.pop('requested_model') # generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
logprob_proc = req_params.pop('logprob_proc', None) generate_params['echo'] = body.get('echo', generate_params['echo'])
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
resp_list_data = [] if not stream:
total_completion_token_count = 0 prompt_arg = body[prompt_str]
total_prompt_token_count = 0 if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
prompt_arg = [prompt_arg]
for idx, prompt in enumerate(prompt_arg, start=0): resp_list_data = []
if isinstance(prompt[0], int): total_completion_token_count = 0
# token lists total_prompt_token_count = 0
if requested_model == shared.model_name:
prompt = decode(prompt)[0] for idx, prompt in enumerate(prompt_arg, start=0):
else: if isinstance(prompt[0], int):
# token lists
if requested_model == shared.model_name:
prompt = decode(prompt)[0]
else:
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
for a in generator:
answer = a
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
yield resp
else:
prompt = body[prompt_str]
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try: try:
encoder = tiktoken.encoding_for_model(requested_model) encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt) prompt = encoder.decode(prompt)
except KeyError: except KeyError:
prompt = decode(prompt)[0] prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
token_count = len(encode(prompt)[0]) token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
- if token_count + max_tokens > req_params['truncation_length']:
- err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
- # print(f"Warning: ${err_msg}")
- raise InvalidRequestError(message=err_msg, param=max_tokens_str)
+ def text_streaming_chunk(content):
+ # begin streaming
+ chunk = {
+ "id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
return chunk
yield text_streaming_chunk('')
- # generate reply #######################################
- debug_msg({'prompt': prompt, 'req_params': req_params})
- generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+ # generate reply #######################################
+ debug_msg({'prompt': prompt, 'generate_params': generate_params})
+ generator = generate_reply(prompt, generate_params, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
- # strip extra leading space off new generated content
+ len_seen = len(seen_content)
+ new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = text_streaming_chunk(new_content)
yield chunk
# to get the correct count, we strip the leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length" stop_reason = "length"
- respi = {
- "index": idx,
- "finish_reason": stop_reason,
- "text": answer,
- "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
- }
+ chunk = text_streaming_chunk('')
+ chunk[resp_list][0]["finish_reason"] = stop_reason
+ chunk["usage"] = {
+ "prompt_tokens": token_count,
+ "completion_tokens": completion_token_count,
+ "total_tokens": token_count + completion_token_count
+ }
resp_list_data.extend([respi])
resp = {
"id": cmpl_id,
"object": object_type,
"created": created_time,
"model": shared.model_name, # TODO: add Lora info?
resp_list: resp_list_data,
"usage": {
"prompt_tokens": total_prompt_token_count,
"completion_tokens": total_completion_token_count,
"total_tokens": total_prompt_token_count + total_completion_token_count
}
}
return resp
# generator
def stream_completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
# object_type = 'text_completion'
stream_object_type = 'text_completion.chunk'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
req_params = marshal_common_params(body)
requested_model = req_params.pop('requested_model')
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
encoder = tiktoken.encoding_for_model(requested_model)
prompt = encoder.decode(prompt)
except KeyError:
prompt = decode(prompt)[0]
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params
req_params['stream'] = True
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
token_count = len(encode(prompt)[0])
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
def text_streaming_chunk(content):
# begin streaming
chunk = {
"id": cmpl_id,
"object": stream_object_type,
"created": created_time,
"model": shared.model_name,
resp_list: [{
"index": 0,
"finish_reason": None,
"text": content,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}],
}
return chunk
yield text_streaming_chunk('')
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
seen_content = ''
completion_token_count = 0
for a in generator:
answer = a
len_seen = len(seen_content)
new_content = answer[len_seen:]
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
chunk = text_streaming_chunk(new_content)
yield chunk
# to get the correct count, we strip the leading space if present
if answer and answer[0] == ' ':
answer = answer[1:]
- completion_token_count = len(encode(answer)[0])
- stop_reason = "stop"
- if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
- stop_reason = "length"
- chunk = text_streaming_chunk('')
- chunk[resp_list][0]["finish_reason"] = stop_reason
- chunk["usage"] = {
- "prompt_tokens": token_count,
- "completion_tokens": completion_token_count,
- "total_tokens": token_count + completion_token_count
- }
- yield chunk
+ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
+ generator = chat_completions_common(body, is_legacy, stream=False)
+ return deque(generator, maxlen=1).pop()
+ def stream_chat_completions(body: dict, is_legacy: bool = False):
+ for resp in chat_completions_common(body, is_legacy, stream=True):
+ yield resp
def completions(body: dict, is_legacy: bool = False) -> dict:
generator = completions_common(body, is_legacy, stream=False)
return deque(generator, maxlen=1).pop()
def stream_completions(body: dict, is_legacy: bool = False):
for resp in completions_common(body, is_legacy, stream=True):
yield resp
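The non-streaming wrappers reuse the same generators as the streaming endpoints and keep only their final yield via `deque(..., maxlen=1).pop()`. A minimal, standalone sketch of that pattern (the `numbers()` generator below is a stand-in for `completions_common(..., stream=False)`, not part of the extension):

```python
from collections import deque

def numbers():
    # stand-in for a generator that yields partial results
    # and, last of all, the complete response
    yield {"text": "Hel"}
    yield {"text": "Hello"}
    yield {"text": "Hello world"}

# a deque with maxlen=1 keeps only the most recent item,
# so .pop() returns the generator's final yield
final = deque(numbers(), maxlen=1).pop()
print(final)  # {'text': 'Hello world'}
```

This avoids duplicating the generation loop for the blocking and streaming code paths.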


@ -1,78 +0,0 @@
import copy
# Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
default_req_params = {
'max_new_tokens': 16, # 'Inf' for chat
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'temperature': 1.0,
'temperature_last': False,
'top_p': 1.0,
'min_p': 0,
'top_k': 1, # choose 20 for chat in absence of another default
'repetition_penalty': 1.18,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': -1,
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': 2048, # first use shared.settings value
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0.0, # In units of 1e-4
'eta_cutoff': 0.0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1.0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'grammar_string': '',
'guidance_scale': 1,
'negative_prompt': '',
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'custom_stopping_strings': '',
# 'logits_processor' - conditionally passed
# 'stopping_strings' - temporarily used
# 'logprobs' - temporarily used
# 'requested_model' - temporarily used
}
def get_default_req_params():
return copy.deepcopy(default_req_params)
def default(dic, key, default):
'''
little helper to get defaults if arg is present but None and should be the same type as default.
'''
val = dic.get(key, default)
if not isinstance(val, type(default)):
# maybe it's just something like 1 instead of 1.0
try:
v = type(default)(val)
if type(val)(v) == val: # if it's the same value passed in, it's ok.
return v
except:
pass
val = default
return val
def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))
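For reference, the removed `default()` helper coerced request values to the type of their fallback and `clamp()` bounded them to a range. A self-contained sketch of that behavior, with made-up sample inputs:

```python
# Standalone illustration of the removed helpers (sample values are made up)
def default(dic, key, default):
    val = dic.get(key, default)
    if not isinstance(val, type(default)):
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # same value round-trips, so the coercion is safe
                return v
        except Exception:
            pass
        val = default
    return val

def clamp(value, minvalue, maxvalue):
    return max(minvalue, min(value, maxvalue))

print(default({'temperature': 1}, 'temperature', 1.0))  # 1.0  (int 1 coerced to float)
print(default({'top_p': 'oops'}, 'top_p', 1.0))         # 1.0  (unconvertible value falls back)
print(default({}, 'max_tokens', 16))                    # 16   (missing key returns the fallback)
print(clamp(3.5, 0.001, 1.999))                         # 1.999
```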


@ -1,101 +0,0 @@
import time
import yaml
from extensions.openai.defaults import get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from modules import shared
from modules.text_generation import encode, generate_reply
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
created_time = int(time.time() * 1000)
# Request parameters
req_params = get_default_req_params()
stopping_strings = []
# Alpaca is verbose so a good default prompt
default_template = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
instruction_template = default_template
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template']:
if 'Alpaca' in shared.settings['instruction_template']:
stopping_strings.extend(['\n###'])
else:
try:
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
except Exception as e:
instruction_template = default_template
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
else:
stopping_strings.extend(['\n###'])
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
edit_task = instruction_template.format(instruction=instruction, input=input)
truncation_length = shared.settings['truncation_length']
token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count
if max_tokens < 1:
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
raise InvalidRequestError(err_msg, param='input')
req_params['max_new_tokens'] = max_tokens
req_params['truncation_length'] = truncation_length
req_params['temperature'] = temperature
req_params['top_p'] = top_p
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
answer = a
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
answer = answer[1:]
completion_token_count = len(encode(answer)[0])
resp = {
"object": "edit",
"created": created_time,
"choices": [{
"text": answer,
"index": 0,
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
return resp


@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
embeddings_params_initialized = False
- # using 'lazy loading' to avoid circular import
- # so this function will be executed only once
def initialize_embedding_params():
+ '''
+ using 'lazy loading' to avoid circular import
+ so this function will be executed only once
+ '''
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
@ -26,7 +30,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
- print(f"\Try embedding model: {model} on {embeddings_device}")
+ print(f"Try embedding model: {model} on {embeddings_device}")
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
embeddings_model = SentenceTransformer(model, device=embeddings_device)
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
@ -54,7 +58,7 @@ def get_embeddings(input: list) -> np.ndarray:
model = get_embeddings_model()
debug_msg(f"embedding model : {model}")
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
debug_msg(f"embedding result : {embedding}")  # might be too long even for debug, use at your own will
return embedding


@ -50,6 +50,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
'data': [] 'data': []
} }
from extensions.openai.script import params from extensions.openai.script import params
# TODO: support SD_WEBUI_AUTH username:password pair. # TODO: support SD_WEBUI_AUTH username:password pair.
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img" sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"


@ -1,4 +1,5 @@
SpeechRecognition==3.10.0
- flask_cloudflared==0.0.12
+ flask_cloudflared==0.0.14
sentence-transformers
+ sse-starlette==1.6.5
tiktoken


@ -1,351 +1,255 @@
import json import json
import os import os
import ssl
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
import extensions.openai.completions as OAIcompletions import extensions.openai.completions as OAIcompletions
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages import extensions.openai.images as OAIimages
import extensions.openai.models as OAImodels import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations import extensions.openai.moderations as OAImoderations
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import (
InvalidRequestError,
OpenAIError,
ServiceUnavailableError
)
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import debug_msg
from modules import shared
import cgi
import speech_recognition as sr import speech_recognition as sr
import uvicorn
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from modules import shared
from modules.logging_colors import logger
from pydub import AudioSegment from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from .typing import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
to_dict
)
params = { params = {
# default params
'port': 5001,
'embedding_device': 'cpu', 'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2', 'embedding_model': 'all-mpnet-base-v2',
# optional params
'sd_webui_url': '', 'sd_webui_url': '',
'debug': 0 'debug': 0
} }
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Credentials", "true")
self.send_header(
"Access-Control-Allow-Methods",
"GET,HEAD,OPTIONS,POST,PUT"
)
self.send_header(
"Access-Control-Allow-Headers",
"Origin, Accept, X-Requested-With, Content-Type, "
"Access-Control-Request-Method, Access-Control-Request-Headers, "
"Authorization"
)
- def do_OPTIONS(self):
- self.send_response(200)
- self.send_access_control_headers()
- self.send_header('Content-Type', 'application/json')
- self.end_headers()
- self.wfile.write("OK".encode('utf-8'))
+ def verify_api_key(authorization: str = Header(None)) -> None:
+ expected_api_key = shared.args.api_key
+ if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
+ raise HTTPException(status_code=401, detail="Unauthorized")
def start_sse(self):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'text/event-stream')
self.send_header('Cache-Control', 'no-cache')
# self.send_header('Connection', 'keep-alive')
self.end_headers()
def send_sse(self, chunk: dict): app = FastAPI(dependencies=[Depends(verify_api_key)])
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
debug_msg(response[:-4])
self.wfile.write(response.encode('utf-8'))
def end_sse(self): # Configure CORS settings to allow all origins, methods, and headers
response = 'data: [DONE]\r\n\r\n' app.add_middleware(
debug_msg(response[:-4]) CORSMiddleware,
self.wfile.write(response.encode('utf-8')) allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
allow_headers=[
"Origin",
"Accept",
"X-Requested-With",
"Content-Type",
"Access-Control-Request-Method",
"Access-Control-Request-Headers",
"Authorization",
],
)
def return_json(self, ret: dict, code: int = 200, no_debug=False):
self.send_response(code)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
response = json.dumps(ret) @app.options("/")
r_utf8 = response.encode('utf-8') async def options_route():
return JSONResponse(content="OK")
self.send_header('Content-Length', str(len(r_utf8)))
self.end_headers()
self.wfile.write(r_utf8) @app.post('/v1/completions', response_model=CompletionResponse)
if not no_debug: @app.post('/v1/generate', response_model=CompletionResponse)
debug_msg(r_utf8) async def openai_completions(request: Request, request_data: CompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''): if request_data.stream:
async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
yield {"data": json.dumps(resp)}
error_resp = { return EventSourceResponse(generator()) # SSE streaming
'error': {
'message': message,
'code': code,
'type': error_type,
'param': param,
}
}
if internal_message:
print(error_type, message)
print(internal_message)
# error_resp['internal_message'] = internal_message
self.return_json(error_resp, code) else:
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
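A rough client-side sketch of the new completions endpoint. Only the `/v1/completions` path and the request fields come from the code above; the port reflects the `--api-port` default introduced by this commit, and the key is a placeholder that only matters when the server is started with `--api-key`:

```python
import requests

url = "http://127.0.0.1:5000/v1/completions"          # default --api-port is 5000
headers = {"Authorization": "Bearer my-secret-key"}    # only checked when --api-key is set

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "temperature": 0.7,
}

r = requests.post(url, json=payload, headers=headers)
r.raise_for_status()
print(r.json()["choices"][0]["text"])
```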
def openai_error_handler(func):
def wrapper(self):
try:
func(self)
except InvalidRequestError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
except OpenAIError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
except Exception as e:
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
return wrapper @app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
@openai_error_handler if request_data.stream:
def do_GET(self): async def generator():
debug_msg(self.requestline) response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
debug_msg(self.headers) for resp in response:
yield {"data": json.dumps(resp)}
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'): return EventSourceResponse(generator()) # SSE streaming
is_legacy = 'engines' in self.path
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
if is_legacy and not is_list:
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
self.return_json(resp) else:
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
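When `stream` is true, the endpoint returns an SSE stream (`sse_starlette.EventSourceResponse`), one `data: {...}` line per chunk. A sketch of consuming it with plain `requests`, assuming the default port and no API key; the exact fields inside each choice are not asserted here:

```python
import json
import requests

url = "http://127.0.0.1:5000/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "mode": "instruct",
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as r:
    for line in r.iter_lines():
        # SSE frames look like b"data: {...}"; skip keep-alives and blank separators
        if not line or not line.startswith(b"data: "):
            continue
        chunk = json.loads(line[len(b"data: "):])
        print(chunk["choices"][0])
```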
elif '/billing/usage' in self.path:
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
self.return_json({"total_usage": 0}, no_debug=True)
else: @app.get("/v1/models")
self.send_error(404) @app.get("/v1/engines")
async def handle_models(request: Request):
path = request.url.path
is_legacy = 'engines' in path
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
@openai_error_handler if is_legacy and not is_list:
def do_POST(self): model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
if '/v1/audio/transcriptions' in self.path: return JSONResponse(content=resp)
r = sr.Recognizer()
# Parse the form data
form = cgi.FieldStorage(
fp=self.rfile,
headers=self.headers,
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
)
audio_file = form['file'].file
audio_data = AudioSegment.from_file(audio_file)
# Convert AudioSegment to raw data
raw_data = audio_data.raw_data
# Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
transcription = {"text": ""} @app.get('/v1/billing/usage')
def handle_billing_usage():
try: '''
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model) Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
except sr.UnknownValueError: '''
print("Whisper could not understand audio") return JSONResponse(content={"total_usage": 0})
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
self.return_json(transcription, no_debug=True)
return
debug_msg(self.requestline)
debug_msg(self.headers)
content_length = self.headers.get('Content-Length')
transfer_encoding = self.headers.get('Transfer-Encoding')
if content_length: @app.post('/v1/audio/transcriptions')
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8')) async def handle_audio_transcription(request: Request):
elif transfer_encoding == 'chunked': r = sr.Recognizer()
chunks = []
while True:
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
if chunk_size == 0:
break # End of chunks
chunks.append(self.rfile.read(chunk_size))
self.rfile.readline() # Consume the trailing newline after each chunk
body = json.loads(b''.join(chunks).decode('utf-8'))
else:
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
self.end_headers()
return
debug_msg(body) form = await request.form()
audio_file = await form["file"].read()
audio_data = AudioSegment.from_file(audio_file)
if '/completions' in self.path or '/generate' in self.path: # Convert AudioSegment to raw data
raw_data = audio_data.raw_data
if not shared.model: # Create AudioData object
raise ServiceUnavailableError("No model loaded.") audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
is_legacy = '/generate' in self.path transcription = {"text": ""}
is_streaming = body.get('stream', False)
if is_streaming: try:
self.start_sse() transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
response = [] return JSONResponse(content=transcription)
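A hedged client sketch for the transcription route above: it posts a multipart form with the `file`, `model`, and `language` fields the handler reads. The file name and port are placeholders:

```python
import requests

url = "http://127.0.0.1:5000/v1/audio/transcriptions"
with open("sample.wav", "rb") as f:
    r = requests.post(url, files={"file": f}, data={"model": "tiny", "language": "en"})

print(r.json()["text"])
```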
if 'chat' in self.path:
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
for resp in response:
self.send_sse(resp)
self.end_sse() @app.post('/v1/images/generations')
async def handle_image_generation(request: Request):
else: if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
response = '' raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
if 'chat' in self.path:
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.completions(body, is_legacy=is_legacy)
self.return_json(response) body = await request.json()
prompt = body['prompt']
size = body.get('size', '1024x1024')
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
elif '/edits' in self.path: response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
# deprecated return JSONResponse(response)
if not shared.model:
raise ServiceUnavailableError("No model loaded.")
req_params = get_default_req_params() @app.post("/v1/embeddings")
async def handle_embeddings(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
instruction = body['instruction'] input = body.get('input', body.get('text', ''))
input = body.get('input', '') if not input:
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 raise HTTPException(status_code=400, detail="Missing required argument input")
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
response = OAIedits.edits(instruction, input, temperature, top_p) if type(input) is str:
input = [input]
self.return_json(response) response = OAIembeddings.embeddings(input, encoding_format)
return JSONResponse(response)
elif '/images/generations' in self.path:
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt'] @app.post("/v1/moderations")
size = default(body, 'size', '1024x1024') async def handle_moderations(request: Request):
response_format = default(body, 'response_format', 'url') # or b64_json body = await request.json()
n = default(body, 'n', 1) # ignore the batch limits of max 10 input = body["input"]
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n) response = OAImoderations.moderations(input)
return JSONResponse(response)
self.return_json(response, no_debug=True)
elif '/embeddings' in self.path: @app.post("/api/v1/token-count")
encoding_format = body.get('encoding_format', '') async def handle_token_count(request: Request):
body = await request.json()
response = token_count(body['prompt'])
return JSONResponse(response)
input = body.get('input', body.get('text', ''))
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
if type(input) is str: @app.post("/api/v1/token/encode")
input = [input] async def handle_token_encode(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
response = token_encode(body["input"], encoding_format)
return JSONResponse(response)
response = OAIembeddings.embeddings(input, encoding_format)
self.return_json(response, no_debug=True) @app.post("/api/v1/token/decode")
async def handle_token_decode(request: Request):
- elif '/moderations' in self.path:
- input = body['input']
- if not input:
- raise InvalidRequestError("Missing required argument input", params='input')
+ body = await request.json()
+ encoding_format = body.get("encoding_format", "")
+ response = token_decode(body["input"], encoding_format)
+ return JSONResponse(response)
response = OAImoderations.moderations(input)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token-count':
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
response = token_count(body['prompt'])
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/encode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_encode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/decode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_decode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
else:
self.send_error(404)
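The `/api/v1/token-*` routes are the non-standard helpers kept for client-side token accounting. A small sketch of calling them (base URL assumed; response shapes printed rather than asserted):

```python
import requests

base = "http://127.0.0.1:5000"
print(requests.post(f"{base}/api/v1/token-count", json={"prompt": "Hello there"}).json())
print(requests.post(f"{base}/api/v1/token/encode", json={"input": "Hello there"}).json())
```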
def run_server():
- port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
- server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
- server = ThreadingHTTPServer(server_addr, Handler)
- ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
- ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
- ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
- if ssl_verify:
- context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- context.load_cert_chain(ssl_certfile, ssl_keyfile)
- server.socket = context.wrap_socket(server.socket, server_side=True)
- if shared.args.share:
- try:
- from flask_cloudflared import _run_cloudflared
- public_url = _run_cloudflared(port, port + 1)
- print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
- except ImportError:
- print('You should install flask_cloudflared manually')
- else:
- if ssl_verify:
- print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
- else:
- print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
- server.serve_forever()
+ server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
+ port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
+ ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
+ ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
+ if shared.args.public_api:
+ def on_start(public_url: str):
+ logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
+ _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
+ else:
+ if ssl_keyfile and ssl_certfile:
+ logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
+ else:
+ logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
+ if shared.args.api_key:
+ logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
+ uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
def setup():
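With the server started as `python server.py --api --api-key my-secret-key` (flags per this commit; the key is optional), the official `openai` Python package can talk to it directly. A sketch assuming the pre-1.0 client interface and the default port; the response is assumed to follow the OpenAI chat shape:

```python
import openai

openai.api_base = "http://127.0.0.1:5000/v1"
openai.api_key = "my-secret-key"  # any placeholder works if --api-key was not set

resp = openai.ChatCompletion.create(
    model="anything",  # assumed not to switch models; the currently loaded model answers
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp["choices"][0]["message"]["content"])
```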

extensions/openai/typing.py (new file, 125 lines added)

@ -0,0 +1,125 @@
import json
import time
from typing import List
from pydantic import BaseModel, Field
class GenerationOptions(BaseModel):
preset: str | None = None
temperature: float = 1
temperature_last: bool = False
top_p: float = 1
min_p: float = 0
top_k: int = 0
repetition_penalty: float = 1
presence_penalty: float = 0
frequency_penalty: float = 0
repetition_penalty_range: int = 0
typical_p: float = 1
tfs: float = 1
top_a: float = 0
epsilon_cutoff: float = 0
eta_cutoff: float = 0
guidance_scale: float = 1
negative_prompt: str = ''
penalty_alpha: float = 0
mirostat_mode: int = 0
mirostat_tau: float = 5
mirostat_eta: float = 0.1
do_sample: bool = True
seed: int = -1
encoder_repetition_penalty: float = 1
no_repeat_ngram_size: int = 0
min_length: int = 0
num_beams: int = 1
length_penalty: float = 1
early_stopping: bool = False
truncation_length: int = 0
max_tokens_second: int = 0
custom_token_bans: str = ""
auto_max_new_tokens: bool = False
ban_eos_token: bool = False
add_bos_token: bool = True
skip_special_tokens: bool = True
grammar_string: str = ""
class CompletionRequest(GenerationOptions):
model: str | None = None
prompt: str | List[str]
best_of: int | None = 1
echo: bool | None = False
frequency_penalty: float | None = 0
logit_bias: dict | None = None
logprobs: int | None = None
max_tokens: int | None = 16
n: int | None = 1
presence_penalty: int | None = 0
stop: str | List[str] | None = None
stream: bool | None = False
suffix: str | None = None
temperature: float | None = 1
top_p: float | None = 1
user: str | None = None
class CompletionResponse(BaseModel):
id: str
choices: List[dict]
created: int = int(time.time())
model: str
object: str = "text_completion"
usage: dict
class ChatCompletionRequest(GenerationOptions):
messages: List[dict]
model: str | None = None
frequency_penalty: float | None = 0
function_call: str | dict | None = None
functions: List[dict] | None = None
logit_bias: dict | None = None
max_tokens: int | None = None
n: int | None = 1
presence_penalty: int | None = 0
stop: str | List[str] | None = None
stream: bool | None = False
temperature: float | None = 1
top_p: float | None = 1
user: str | None = None
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
name2: str | None = Field(default=None, description="Overwrites the value set by character.")
context: str | None = Field(default=None, description="Overwrites the value set by character.")
greeting: str | None = Field(default=None, description="Overwrites the value set by character.")
chat_instruct_command: str | None = None
continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
class ChatCompletionResponse(BaseModel):
id: str
choices: List[dict]
created: int = int(time.time())
model: str
object: str = "chat.completion"
usage: dict
def to_json(obj):
return json.dumps(obj.__dict__, indent=4)
def to_dict(obj):
return obj.__dict__
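A quick sanity check of the new request models, assuming the repo root is on the Python path and Python 3.10+ (the `str | None` annotations require it):

```python
from extensions.openai.typing import CompletionRequest, to_dict

req = CompletionRequest(prompt="Hello", max_tokens=32, temperature=0.7)
print(req.model)                 # None -> the currently loaded model is used
d = to_dict(req)
print(d["prompt"], d["max_tokens"])  # Hello 32
```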


@ -1,8 +1,12 @@
import base64
import os
import time
import traceback
from typing import Callable, Optional
import numpy as np
def float_list_to_base64(float_array: np.ndarray) -> str:
# Convert the list to a float32 array that the OpenAPI client expects
# float_array = np.array(float_list, dtype="float32")
@ -18,13 +22,33 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
return ascii_string
def end_line(s):
if s and s[-1] != '\n':
s = s + '\n'
return s
def debug_msg(*args, **kwargs):
from extensions.openai.script import params
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
print(*args, **kwargs)
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
if tunnel_id is not None:
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
else:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')
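Hypothetical usage of the retry helper above, with an illustrative port and callback (only the signature comes from the code):

```python
def announce(public_url: str):
    # called once the tunnel is up
    print(f"OpenAI compatible API URL: {public_url}/v1")

_start_cloudflared(5000, tunnel_id=None, max_attempts=3, on_start=announce)
```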


@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Find the maximum prompt size
max_length = get_max_prompt_length(state)
all_substrings = {
- 'chat': get_turn_substrings(state, instruct=False),
+ 'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
'instruct': get_turn_substrings(state, instruct=True)
}
@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
# Extract the reply
- visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
+ visible_reply = reply
+ if state['mode'] in ['chat', 'chat-instruct']:
+ visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = html.escape(visible_reply)
if shared.stop_everything:
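The substitution that now runs only in chat modes simply rewrites the user placeholders in the reply. A small illustration with made-up values:

```python
import re

reply = "Nice to meet you, <USER>. How can {{user}} be helped today?"
name1 = "Alice"  # illustrative user name

# same pattern as in chatbot_wrapper above
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1, reply)
print(visible_reply)  # Nice to meet you, Alice. How can Alice be helped today?
```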


@ -71,11 +71,12 @@ def load_model(model_name, loader=None):
'AutoAWQ': AutoAWQ_loader, 'AutoAWQ': AutoAWQ_loader,
} }
metadata = get_model_metadata(model_name)
if loader is None: if loader is None:
if shared.args.loader is not None: if shared.args.loader is not None:
loader = shared.args.loader loader = shared.args.loader
else: else:
loader = get_model_metadata(model_name)['loader'] loader = metadata['loader']
if loader is None: if loader is None:
logger.error('The path to the model does not exist. Exiting.') logger.error('The path to the model does not exist. Exiting.')
return None, None return None, None
@ -95,6 +96,7 @@ def load_model(model_name, loader=None):
if any((shared.args.xformers, shared.args.sdp_attention)): if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention() llama_attn_hijack.hijack_llama_attention()
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.") logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer


@ -6,33 +6,32 @@ import yaml
def default_preset(): def default_preset():
return { return {
'do_sample': True,
'temperature': 1, 'temperature': 1,
'temperature_last': False, 'temperature_last': False,
'top_p': 1, 'top_p': 1,
'min_p': 0, 'min_p': 0,
'top_k': 0, 'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1, 'repetition_penalty': 1,
'presence_penalty': 0, 'presence_penalty': 0,
'frequency_penalty': 0, 'frequency_penalty': 0,
'repetition_penalty_range': 0, 'repetition_penalty_range': 0,
'typical_p': 1,
'tfs': 1,
'top_a': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'guidance_scale': 1,
'penalty_alpha': 0,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'encoder_repetition_penalty': 1, 'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,
'min_length': 0, 'min_length': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'penalty_alpha': 0,
'num_beams': 1, 'num_beams': 1,
'length_penalty': 1, 'length_penalty': 1,
'early_stopping': False, 'early_stopping': False,
'custom_token_bans': '',
} }


@ -39,21 +39,21 @@ settings = {
'max_new_tokens': 200, 'max_new_tokens': 200,
'max_new_tokens_min': 1, 'max_new_tokens_min': 1,
'max_new_tokens_max': 4096, 'max_new_tokens_max': 4096,
'seed': -1,
'negative_prompt': '', 'negative_prompt': '',
'seed': -1,
'truncation_length': 2048, 'truncation_length': 2048,
'truncation_length_min': 0, 'truncation_length_min': 0,
'truncation_length_max': 32768, 'truncation_length_max': 32768,
'custom_stopping_strings': '',
'auto_max_new_tokens': False,
'max_tokens_second': 0, 'max_tokens_second': 0,
'ban_eos_token': False, 'custom_stopping_strings': '',
'custom_token_bans': '', 'custom_token_bans': '',
'auto_max_new_tokens': False,
'ban_eos_token': False,
'add_bos_token': True, 'add_bos_token': True,
'skip_special_tokens': True, 'skip_special_tokens': True,
'stream': True, 'stream': True,
'name1': 'You',
'character': 'Assistant', 'character': 'Assistant',
'name1': 'You',
'instruction_template': 'Alpaca', 'instruction_template': 'Alpaca',
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'autoload_model': False, 'autoload_model': False,
@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
- parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
- parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
+ parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
+ parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
# Multimodal # Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED')
parser.add_argument('--chat', action='store_true', help='DEPRECATED') parser.add_argument('--chat', action='store_true', help='DEPRECATED')
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED') parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED') parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED')
args = parser.parse_args() args = parser.parse_args()
args_defaults = parser.parse_args([]) args_defaults = parser.parse_args([])
@ -233,10 +235,13 @@ def fix_loader_name(name):
return 'AutoAWQ' return 'AutoAWQ'
- def add_extension(name):
+ def add_extension(name, last=False):
if args.extensions is None:
args.extensions = [name]
- elif 'api' not in args.extensions:
+ elif last:
+ args.extensions = [x for x in args.extensions if x != name]
+ args.extensions.append(name)
+ elif name not in args.extensions:
args.extensions.append(name)
@ -246,14 +251,15 @@ def is_chat():
args.loader = fix_loader_name(args.loader) args.loader = fix_loader_name(args.loader)
# Activate the API extension
if args.api or args.public_api:
add_extension('api')
# Activate the multimodal extension # Activate the multimodal extension
if args.multimodal_pipeline is not None: if args.multimodal_pipeline is not None:
add_extension('multimodal') add_extension('multimodal')
# Activate the API extension
if args.api:
# add_extension('openai', last=True)
add_extension('api', last=True)
# Load model-specific settings # Load model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p: with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists(): if p.exists():


@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
# Find the stopping strings
all_stop_strings = []
- for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+ for st in (stopping_strings, state['custom_stopping_strings']):
+ if type(st) is str:
+ st = ast.literal_eval(f"[{st}]")
if type(st) is list and len(st) > 0:
all_stop_strings += st
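The loop now accepts either a pre-parsed list or the raw `custom_stopping_strings` string. A standalone sketch of the same logic (the helper name and sample values are illustrative, not part of the codebase):

```python
import ast

def collect_stop_strings(stopping_strings, custom_stopping_strings):
    # mirrors the loop above: strings are parsed with ast.literal_eval, lists pass through
    all_stop_strings = []
    for st in (stopping_strings, custom_stopping_strings):
        if type(st) is str:
            st = ast.literal_eval(f"[{st}]")
        if type(st) is list and len(st) > 0:
            all_stop_strings += st
    return all_stop_strings

print(collect_stop_strings(["\n###"], '"\\nYou:", "\\nBot:"'))
# ['\n###', '\nYou:', '\nBot:']
```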


@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False):
if 'instruction_template' in settings:
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
- # Applying the changes to the global shared settings (in-memory)
- shared.settings.update({k: v for k, v in settings.items() if k in shared.settings})
yield output
else:
yield f"Failed to load `{selected_model}`."