mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
Make OpenAI API the default API (#4430)
This commit is contained in:
parent
84d957ba62
commit
ec17a5d2b7
@ -22,7 +22,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character)
|
||||
* Very efficient text streaming
|
||||
* Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai)
|
||||
* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples))
|
||||
* OpenAI-compatible API server
|
||||
|
||||
## Documentation
|
||||
|
||||
@ -412,8 +412,8 @@ Optionally, you can use the following command-line flags:
|
||||
| `--api` | Enable the API extension. |
|
||||
| `--public-api` | Create a public URL for the API using Cloudfare. |
|
||||
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
|
||||
| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. |
|
||||
| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. |
|
||||
| `--api-port API_PORT` | The listening port for the API. |
|
||||
| `--api-key API_KEY` | API authentication key. |
|
||||
|
||||
#### Multimodal
|
||||
|
||||
|
@ -1,114 +0,0 @@
|
||||
import asyncio
|
||||
import html
|
||||
import json
|
||||
import sys
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5005'
|
||||
URI = f'ws://{HOST}/api/v1/chat-stream'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
|
||||
async def run(user_input, history):
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
||||
'your_name': 'You',
|
||||
# 'name1': 'name of user', # Optional
|
||||
# 'name2': 'name of character', # Optional
|
||||
# 'context': 'character context', # Optional
|
||||
# 'greeting': 'greeting', # Optional
|
||||
# 'name1_instruct': 'You', # Optional
|
||||
# 'name2_instruct': 'Assistant', # Optional
|
||||
# 'context_instruct': 'context_instruct', # Optional
|
||||
# 'turn_template': 'turn_template', # Optional
|
||||
'regenerate': False,
|
||||
'_continue': False,
|
||||
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
'preset': 'None',
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'grammar_string': '',
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'custom_token_bans': '',
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
while True:
|
||||
incoming_data = await websocket.recv()
|
||||
incoming_data = json.loads(incoming_data)
|
||||
|
||||
match incoming_data['event']:
|
||||
case 'text_stream':
|
||||
yield incoming_data['history']
|
||||
case 'stream_end':
|
||||
return
|
||||
|
||||
|
||||
async def print_response_stream(user_input, history):
|
||||
cur_len = 0
|
||||
async for new_history in run(user_input, history):
|
||||
cur_message = new_history['visible'][-1][1][cur_len:]
|
||||
cur_len += len(cur_message)
|
||||
print(html.unescape(cur_message), end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||
|
||||
# Basic example
|
||||
history = {'internal': [], 'visible': []}
|
||||
|
||||
# "Continue" example. Make sure to set '_continue' to True above
|
||||
# arr = [user_input, 'Surely, here is']
|
||||
# history = {'internal': [arr], 'visible': [arr]}
|
||||
|
||||
asyncio.run(print_response_stream(user_input, history))
|
@ -1,94 +0,0 @@
|
||||
import html
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - http://
|
||||
HOST = 'localhost:5000'
|
||||
URI = f'http://{HOST}/api/v1/chat'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat'
|
||||
|
||||
|
||||
def run(user_input, history):
|
||||
request = {
|
||||
'user_input': user_input,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'history': history,
|
||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||
'character': 'Example',
|
||||
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
||||
'your_name': 'You',
|
||||
# 'name1': 'name of user', # Optional
|
||||
# 'name2': 'name of character', # Optional
|
||||
# 'context': 'character context', # Optional
|
||||
# 'greeting': 'greeting', # Optional
|
||||
# 'name1_instruct': 'You', # Optional
|
||||
# 'name2_instruct': 'Assistant', # Optional
|
||||
# 'context_instruct': 'context_instruct', # Optional
|
||||
# 'turn_template': 'turn_template', # Optional
|
||||
'regenerate': False,
|
||||
'_continue': False,
|
||||
'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
'preset': 'None',
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'grammar_string': '',
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'custom_token_bans': '',
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
response = requests.post(URI, json=request)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()['results'][0]['history']
|
||||
print(json.dumps(result, indent=4))
|
||||
print()
|
||||
print(html.unescape(result['visible'][-1][1]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard."
|
||||
|
||||
# Basic example
|
||||
history = {'internal': [], 'visible': []}
|
||||
|
||||
# "Continue" example. Make sure to set '_continue' to True above
|
||||
# arr = [user_input, 'Surely, here is']
|
||||
# history = {'internal': [arr], 'visible': [arr]}
|
||||
|
||||
run(user_input, history)
|
@ -1,88 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5005'
|
||||
URI = f'ws://{HOST}/api/v1/stream'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
|
||||
async def run(context):
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
'prompt': context,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
'preset': 'None',
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'grammar_string': '',
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'custom_token_bans': '',
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
yield context # Remove this if you just want to see the reply
|
||||
|
||||
while True:
|
||||
incoming_data = await websocket.recv()
|
||||
incoming_data = json.loads(incoming_data)
|
||||
|
||||
match incoming_data['event']:
|
||||
case 'text_stream':
|
||||
yield incoming_data['text']
|
||||
case 'stream_end':
|
||||
return
|
||||
|
||||
|
||||
async def print_response_stream(prompt):
|
||||
async for response in run(prompt):
|
||||
print(response, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
asyncio.run(print_response_stream(prompt))
|
@ -1,65 +0,0 @@
|
||||
import requests
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - http://
|
||||
HOST = 'localhost:5000'
|
||||
URI = f'http://{HOST}/api/v1/generate'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||
|
||||
|
||||
def run(prompt):
|
||||
request = {
|
||||
'prompt': prompt,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
'preset': 'None',
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'grammar_string': '',
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'custom_token_bans': '',
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
response = requests.post(URI, json=request)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()['results'][0]['text']
|
||||
print(prompt + result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
run(prompt)
|
@ -1,124 +1,64 @@
|
||||
# An OpenedAI API (openai like)
|
||||
## OpenAI compatible API
|
||||
|
||||
This extension creates an API that works kind of like openai (ie. api.openai.com).
|
||||
This project includes an API compatible with multiple OpenAI endpoints, including Chat and Completions.
|
||||
|
||||
## Setup & installation
|
||||
|
||||
Install the requirements:
|
||||
If you did not use the one-click installers, you may need to install the requirements first:
|
||||
|
||||
```
|
||||
pip3 install -r requirements.txt
|
||||
pip install -r extensions/openai/requirements.txt
|
||||
```
|
||||
|
||||
It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this.
|
||||
### Starting the API
|
||||
|
||||
Make sure you enable it in server launch parameters, it should include:
|
||||
Add `--extensions openai` to your command-line flags.
|
||||
|
||||
* To create a public Cloudflare URL, add the `--public-api` flag.
|
||||
* To listen on your local network, add the `--listen` flag.
|
||||
* To change the port, which is 5000 by default, use `--port 1234` (change 1234 to your desired port number).
|
||||
* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`.
|
||||
|
||||
#### Environment variables
|
||||
|
||||
The following environment variables can be used (they take precendence over everything else):
|
||||
|
||||
| Variable Name | Description | Example Value |
|
||||
|------------------------|------------------------------------|----------------------------|
|
||||
| `OPENEDAI_PORT` | Port number | 5000 |
|
||||
| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem |
|
||||
| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem |
|
||||
| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 |
|
||||
| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 |
|
||||
| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 |
|
||||
| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda |
|
||||
|
||||
#### Persistent settings with `settings.yaml`
|
||||
|
||||
You can also set default values by adding these lines to your `settings.yaml` file:
|
||||
|
||||
```
|
||||
--extensions openai
|
||||
```
|
||||
|
||||
You can also use the `--listen` argument to make the server available on the networ, and/or the `--share` argument to enable a public Cloudflare endpoint.
|
||||
|
||||
To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
SD_WEBUI_URL=http://127.0.0.1:7861
|
||||
```
|
||||
|
||||
## Quick start
|
||||
|
||||
1. Install the requirements.txt (pip)
|
||||
2. Enable the `openeai` module (--extensions openai), restart the server.
|
||||
3. Configure the openai client
|
||||
|
||||
Most openai application can be configured to connect the API if you set the following environment variables:
|
||||
|
||||
```shell
|
||||
# Sample .env file:
|
||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
||||
```
|
||||
|
||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
||||
|
||||
|
||||
### Settings
|
||||
|
||||
To adjust your default settings, you can add the following to your `settings.yaml` file.
|
||||
|
||||
```
|
||||
openai-port: 5002
|
||||
openai-embedding_device: cuda
|
||||
openai-embedding_model: all-mpnet-base-v2
|
||||
openai-sd_webui_url: http://127.0.0.1:7861
|
||||
openai-debug: 1
|
||||
```
|
||||
|
||||
If you've configured the environment variables, please note that settings from `settings.yaml` won't take effect. For instance, if you set `openai-port: 5002` in `settings.yaml` but `OPENEDAI_PORT=5001` in the environment variables, the extension will use `5001` as the port number.
|
||||
|
||||
When using `cache_embedding_model.py` to preload the embedding model during Docker image building, consider the following:
|
||||
|
||||
- If you wish to use the default settings, leave the environment variables unset.
|
||||
- If you intend to change the default embedding model, ensure that you configure the environment variable `OPENEDAI_EMBEDDING_MODEL` to the desired model. Avoid setting `openai-embedding_model` in `settings.yaml` because those settings only take effect after the server starts.
|
||||
|
||||
### Models
|
||||
|
||||
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
|
||||
|
||||
For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start.
|
||||
|
||||
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
|
||||
|
||||
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good.
|
||||
|
||||
For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date.
|
||||
A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results.
|
||||
|
||||
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
|
||||
|
||||
```
|
||||
.*wizard.*vicuna:
|
||||
mode: 'instruct'
|
||||
instruction_template: 'Vicuna-v1.1'
|
||||
```
|
||||
|
||||
This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this:
|
||||
|
||||
```
|
||||
user: "USER:"
|
||||
bot: "ASSISTANT:"
|
||||
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
|
||||
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
|
||||
```
|
||||
|
||||
For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results.
|
||||
|
||||
If you see this in your logs, it probably means that the correct format could not be loaded:
|
||||
|
||||
```
|
||||
Warning: Loaded default instruction-following template for model.
|
||||
```
|
||||
|
||||
### Embeddings (alpha)
|
||||
|
||||
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
||||
|
||||
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
||||
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
||||
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
|
||||
| text-davinci-002 | 768 | 2046 | - | - | - |
|
||||
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
||||
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
||||
|
||||
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
||||
|
||||
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
||||
### Examples
|
||||
|
||||
### Client Application Setup
|
||||
|
||||
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
|
||||
|
||||
You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables:
|
||||
|
||||
```shell
|
||||
OPENAI_API_HOST=http://127.0.0.1:5000
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
OPENAI_API_BASE=http://127.0.0.1:500/v1
|
||||
```
|
||||
|
||||
With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables:
|
||||
|
||||
@ -128,7 +68,7 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||
OPENAI_API_BASE=http://0.0.0.0:5001/v1
|
||||
```
|
||||
|
||||
If needed, replace 0.0.0.0 with the IP/port of your server.
|
||||
If needed, replace 127.0.0.1 with the IP/port of your server.
|
||||
|
||||
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported:
|
||||
|
||||
@ -157,8 +97,22 @@ const api = new ChatGPTAPI({
|
||||
apiBaseUrl: process.env.OPENAI_API_BASE
|
||||
});
|
||||
```
|
||||
### Embeddings (alpha)
|
||||
|
||||
## API Documentation & Examples
|
||||
Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
||||
|
||||
| model name | dimensions | input max tokens | speed | size | Avg. performance |
|
||||
| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- |
|
||||
| text-embedding-ada-002 | 1536 | 8192 | - | - | - |
|
||||
| text-davinci-002 | 768 | 2046 | - | - | - |
|
||||
| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 |
|
||||
| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 |
|
||||
|
||||
In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2".
|
||||
|
||||
Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable.
|
||||
|
||||
### API Documentation & Examples
|
||||
|
||||
The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference
|
||||
|
||||
@ -185,7 +139,7 @@ text = response['choices'][0]['message']['content']
|
||||
print(text)
|
||||
```
|
||||
|
||||
## Compatibility & not so compatibility
|
||||
### Compatibility & not so compatibility
|
||||
|
||||
| API endpoint | tested with | notes |
|
||||
| ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- |
|
||||
@ -195,7 +149,7 @@ print(text)
|
||||
| /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings |
|
||||
| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
|
||||
| /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
|
||||
| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
|
||||
| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead |
|
||||
| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
|
||||
| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
|
||||
| /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint |
|
||||
@ -209,28 +163,8 @@ print(text)
|
||||
| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
|
||||
| /v1/search | openai.search, engines.search | not yet supported |
|
||||
|
||||
Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose.
|
||||
|
||||
Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly.
|
||||
|
||||
Some hacky mappings:
|
||||
|
||||
| OpenAI | text-generation-webui | note |
|
||||
| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| model | - | Ignored, the model is not changed |
|
||||
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
|
||||
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
|
||||
| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) |
|
||||
| n | 1 | variations are not supported yet. |
|
||||
| 1 | num_beams | hardcoded to 1 |
|
||||
| 1.0 | typical_p | hardcoded to 1.0 |
|
||||
| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
|
||||
| messages.name | - | not supported yet |
|
||||
| suffix | - | not supported yet |
|
||||
| user | - | not supported yet |
|
||||
| functions/function_call | - | function calls are not supported yet |
|
||||
|
||||
### Applications
|
||||
#### Applications
|
||||
|
||||
Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions.
|
||||
|
||||
@ -249,15 +183,3 @@ Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment v
|
||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||
|
||||
## Future plans
|
||||
|
||||
- better error handling
|
||||
- model changing, esp. something for swapping loras or embedding models
|
||||
- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||
|
||||
## Bugs? Feedback? Comments? Pull requests?
|
||||
|
||||
To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable.
|
||||
|
||||
Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.
|
@ -3,9 +3,11 @@ import time
|
||||
import extensions.api.blocking_api as blocking_api
|
||||
import extensions.api.streaming_api as streaming_api
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
|
||||
def setup():
|
||||
logger.warning("The current API is deprecated and will be replaced with the OpenAI compatible API on November xxth. To test the new API, use \"--extensions openai\" instead of \"--api\".")
|
||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id)
|
||||
if shared.args.public_api:
|
||||
time.sleep(5)
|
||||
|
@ -1,18 +1,23 @@
|
||||
import copy
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import yaml
|
||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||
from extensions.openai.errors import InvalidRequestError
|
||||
from extensions.openai.utils import debug_msg, end_line
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules import shared
|
||||
from modules.chat import (
|
||||
generate_chat_prompt,
|
||||
generate_chat_reply,
|
||||
load_character_memoized
|
||||
)
|
||||
from modules.presets import load_preset_memoized
|
||||
from modules.text_generation import decode, encode, generate_reply
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
|
||||
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
def __init__(self, logit_bias={}):
|
||||
self.logit_bias = logit_bias
|
||||
@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
|
||||
logits[0, self.keys] += self.values
|
||||
debug_msg(" --> ", logits[0, self.keys])
|
||||
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
||||
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor):
|
||||
top_probs = [float(x) for x in top_values[0]]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(repr(self))
|
||||
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
return logprobs
|
||||
|
||||
|
||||
def marshal_common_params(body):
|
||||
# Request Parameters
|
||||
# Try to use openai defaults or map them to something with the same intent
|
||||
def process_parameters(body, is_legacy=False):
|
||||
generate_params = body
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
|
||||
if generate_params['truncation_length'] == 0:
|
||||
if shared.args.loader and shared.args.loader.lower().startswith('exllama'):
|
||||
generate_params['truncation_length'] = shared.args.max_seq_len
|
||||
elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
generate_params['truncation_length'] = shared.args.n_ctx
|
||||
else:
|
||||
generate_params['truncation_length'] = shared.settings['truncation_length']
|
||||
|
||||
req_params = get_default_req_params()
|
||||
|
||||
# Common request parameters
|
||||
req_params['truncation_length'] = shared.settings['truncation_length']
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
# OpenAI API Parameters
|
||||
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||
if body['preset'] is not None:
|
||||
preset = load_preset_memoized(body['preset'])
|
||||
generate_params.update(preset)
|
||||
|
||||
generate_params['custom_stopping_strings'] = []
|
||||
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||
if isinstance(body['stop'], str):
|
||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||
generate_params['custom_stopping_strings'] = [body['stop']]
|
||||
elif isinstance(body['stop'], list):
|
||||
req_params['stopping_strings'] = body['stop']
|
||||
|
||||
# presence_penalty - ignored
|
||||
# frequency_penalty - ignored
|
||||
|
||||
# pass through unofficial params
|
||||
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
|
||||
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
|
||||
|
||||
# user - ignored
|
||||
generate_params['custom_stopping_strings'] = body['stop']
|
||||
|
||||
logits_processor = []
|
||||
logit_bias = body.get('logit_bias', None)
|
||||
@ -110,12 +102,13 @@ def marshal_common_params(body):
|
||||
# XXX convert tokens from tiktoken based on requested model
|
||||
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||
encoder = tiktoken.encoding_for_model(generate_params['model'])
|
||||
new_logit_bias = {}
|
||||
for logit, bias in logit_bias.items():
|
||||
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
||||
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
||||
continue
|
||||
|
||||
new_logit_bias[str(int(x))] = bias
|
||||
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
@ -126,238 +119,129 @@ def marshal_common_params(body):
|
||||
|
||||
logprobs = None # coming to chat eventually
|
||||
if 'logprobs' in body:
|
||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([req_params['logprob_proc']])
|
||||
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([generate_params['logprob_proc']])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if logits_processor: # requires logits_processor support
|
||||
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
|
||||
return req_params
|
||||
return generate_params
|
||||
|
||||
|
||||
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
# functions
|
||||
if body.get('functions', []): # chat only
|
||||
def convert_history(history):
|
||||
'''
|
||||
Chat histories in this program are in the format [message, reply].
|
||||
This function converts OpenAI histories to that format.
|
||||
'''
|
||||
chat_dialogue = []
|
||||
current_message = ""
|
||||
current_reply = ""
|
||||
user_input = ""
|
||||
|
||||
for entry in history:
|
||||
content = entry["content"]
|
||||
role = entry["role"]
|
||||
|
||||
if role == "user":
|
||||
user_input = content
|
||||
if current_message:
|
||||
chat_dialogue.append([current_message, ''])
|
||||
current_message = ""
|
||||
current_message = content
|
||||
elif role == "assistant":
|
||||
current_reply = content
|
||||
if current_message:
|
||||
chat_dialogue.append([current_message, current_reply])
|
||||
current_message = ""
|
||||
current_reply = ""
|
||||
else:
|
||||
chat_dialogue.append(['', current_reply])
|
||||
|
||||
# if current_message:
|
||||
# chat_dialogue.append([current_message, ''])
|
||||
|
||||
return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
|
||||
|
||||
|
||||
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
|
||||
if body.get('functions', []):
|
||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||
|
||||
if body.get('function_call', ''):
|
||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||
|
||||
if 'messages' not in body:
|
||||
raise InvalidRequestError(message="messages is required", param='messages')
|
||||
|
||||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
'user': 'User: {message}\n',
|
||||
'assistant': 'Assistant: {message}\n',
|
||||
'system': '{message}',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
|
||||
'prompt': 'Assistant:',
|
||||
}
|
||||
|
||||
if 'stopping_strings' not in req_params:
|
||||
req_params['stopping_strings'] = []
|
||||
|
||||
# Instruct models can be much better
|
||||
if shared.settings['instruction_template']:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct.get('context', '') # can be missing
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
'system': system_message_template,
|
||||
'context': system_message_default,
|
||||
'prompt': bot_prompt,
|
||||
}
|
||||
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
req_params['stopping_strings'].extend(['\n###'])
|
||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
|
||||
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
chat_msgs = []
|
||||
|
||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||
context_msg = end_line(context_msg)
|
||||
|
||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||
if 'prompt' in body:
|
||||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||
|
||||
for m in messages:
|
||||
if 'role' not in m:
|
||||
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||
elif m['role'] == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
if 'content' not in m:
|
||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||
|
||||
role = m['role']
|
||||
content = m['content']
|
||||
# name = m.get('name', None)
|
||||
# function_call = m.get('function_call', None) # user name or function name with output in content
|
||||
msg = role_formats[role].format(message=content)
|
||||
if role == 'system':
|
||||
system_msgs.extend([msg])
|
||||
elif role == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
else:
|
||||
chat_msgs.extend([msg])
|
||||
|
||||
system_msg = '\n'.join(system_msgs)
|
||||
system_msg = end_line(system_msg)
|
||||
|
||||
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count >= req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||
raise InvalidRequestError(message=err_msg, param='messages')
|
||||
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
# raise InvalidRequestError(message=err_msg, params='max_tokens')
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
# Chat Completions
|
||||
object_type = 'chat.completions'
|
||||
object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
# generation parameters
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
continue_ = body['continue_']
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
# Instruction template
|
||||
instruction_template = body['instruction_template'] or shared.settings['instruction_template']
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||
name1_instruct = body['name1_instruct'] or name1_instruct
|
||||
name2_instruct = body['name2_instruct'] or name2_instruct
|
||||
context_instruct = body['context_instruct'] or context_instruct
|
||||
turn_template = body['turn_template'] or turn_template
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
# Chat character
|
||||
character = body['character'] or shared.settings['character']
|
||||
name1 = body['name1'] or shared.settings['name1']
|
||||
name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False)
|
||||
name2 = body['name2'] or name2
|
||||
context = body['context'] or context
|
||||
greeting = body['greeting'] or greeting
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
# History
|
||||
user_input, history = convert_history(messages)
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
generate_params.update({
|
||||
'mode': body['mode'],
|
||||
'name1': name1,
|
||||
'name2': name2,
|
||||
'context': context,
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': context_instruct,
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': body['chat_instruct_command'],
|
||||
'history': history,
|
||||
'stream': stream
|
||||
})
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens in [None, 0]:
|
||||
generate_params['max_new_tokens'] = 200
|
||||
generate_params['auto_max_new_tokens'] = True
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
|
||||
# Chat Completions
|
||||
stream_object_type = 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = True
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
|
||||
def chat_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
@ -376,262 +260,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
# chunk[resp_list][0]["logprobs"] = None
|
||||
return chunk
|
||||
|
||||
yield chat_streaming_chunk('')
|
||||
if stream:
|
||||
yield chat_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
prompt = generate_chat_prompt(user_input, generate_params)
|
||||
token_count = len(encode(prompt)[0])
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
generator = generate_chat_reply(
|
||||
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
answer = a['internal'][-1][1]
|
||||
if stream:
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
seen_content = answer
|
||||
|
||||
seen_content = answer
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct token_count, strip leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = chat_streaming_chunk('')
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
if stream:
|
||||
chunk = chat_streaming_chunk('')
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
yield chunk
|
||||
else:
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
yield resp
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
object_type = 'text_completion'
|
||||
def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||
object_type = 'text_completion.chunk' if stream else 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
if prompt_str not in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt_arg = body[prompt_str]
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
generate_params['stream'] = stream
|
||||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
|
||||
generate_params['echo'] = body.get('echo', generate_params['echo'])
|
||||
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
if not stream:
|
||||
prompt_arg = body[prompt_str]
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
yield resp
|
||||
else:
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
chunk = text_streaming_chunk('')
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
# object_type = 'text_completion'
|
||||
stream_object_type = 'text_completion.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
if prompt_str not in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt = body[prompt_str]
|
||||
req_params = marshal_common_params(body)
|
||||
requested_model = req_params.pop('requested_model')
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
# common params
|
||||
req_params['stream'] = True
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
generator = chat_completions_common(body, is_legacy, stream=False)
|
||||
return deque(generator, maxlen=1).pop()
|
||||
|
||||
chunk = text_streaming_chunk('')
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
for resp in chat_completions_common(body, is_legacy, stream=True):
|
||||
yield resp
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
generator = completions_common(body, is_legacy, stream=False)
|
||||
return deque(generator, maxlen=1).pop()
|
||||
|
||||
|
||||
def stream_completions(body: dict, is_legacy: bool = False):
|
||||
for resp in completions_common(body, is_legacy, stream=True):
|
||||
yield resp
|
||||
|
@ -1,78 +0,0 @@
|
||||
import copy
|
||||
|
||||
# Slightly different defaults for OpenAI's API
|
||||
# Data type is important, Ex. use 0.0 for a float 0
|
||||
default_req_params = {
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'temperature': 1.0,
|
||||
'temperature_last': False,
|
||||
'top_p': 1.0,
|
||||
'min_p': 0,
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
'repetition_penalty': 1.18,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'suffix': None,
|
||||
'stream': False,
|
||||
'echo': False,
|
||||
'seed': -1,
|
||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||
'truncation_length': 2048, # first use shared.settings value
|
||||
'add_bos_token': True,
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
||||
'eta_cutoff': 0.0, # In units of 1e-4
|
||||
'tfs': 1.0,
|
||||
'top_a': 0.0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0.0,
|
||||
'length_penalty': 1.0,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0,
|
||||
'mirostat_eta': 0.1,
|
||||
'grammar_string': '',
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
'ban_eos_token': False,
|
||||
'custom_token_bans': '',
|
||||
'skip_special_tokens': True,
|
||||
'custom_stopping_strings': '',
|
||||
# 'logits_processor' - conditionally passed
|
||||
# 'stopping_strings' - temporarily used
|
||||
# 'logprobs' - temporarily used
|
||||
# 'requested_model' - temporarily used
|
||||
}
|
||||
|
||||
|
||||
def get_default_req_params():
|
||||
return copy.deepcopy(default_req_params)
|
||||
|
||||
|
||||
def default(dic, key, default):
|
||||
'''
|
||||
little helper to get defaults if arg is present but None and should be the same type as default.
|
||||
'''
|
||||
val = dic.get(key, default)
|
||||
if not isinstance(val, type(default)):
|
||||
# maybe it's just something like 1 instead of 1.0
|
||||
try:
|
||||
v = type(default)(val)
|
||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
||||
return v
|
||||
except:
|
||||
pass
|
||||
|
||||
val = default
|
||||
return val
|
||||
|
||||
|
||||
def clamp(value, minvalue, maxvalue):
|
||||
return max(minvalue, min(value, maxvalue))
|
@ -1,101 +0,0 @@
|
||||
import time
|
||||
|
||||
import yaml
|
||||
from extensions.openai.defaults import get_default_req_params
|
||||
from extensions.openai.errors import InvalidRequestError
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
||||
|
||||
created_time = int(time.time() * 1000)
|
||||
|
||||
# Request parameters
|
||||
req_params = get_default_req_params()
|
||||
stopping_strings = []
|
||||
|
||||
# Alpaca is verbose so a good default prompt
|
||||
default_template = (
|
||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
)
|
||||
|
||||
instruction_template = default_template
|
||||
|
||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||
if shared.settings['instruction_template']:
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
stopping_strings.extend(['\n###'])
|
||||
else:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
template = template\
|
||||
.replace('<|user|>', instruct.get('user', ''))\
|
||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||
|
||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||
if instruct['user']:
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
except Exception as e:
|
||||
instruction_template = default_template
|
||||
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
else:
|
||||
stopping_strings.extend(['\n###'])
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
|
||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||
|
||||
truncation_length = shared.settings['truncation_length']
|
||||
|
||||
token_count = len(encode(edit_task)[0])
|
||||
max_tokens = truncation_length - token_count
|
||||
|
||||
if max_tokens < 1:
|
||||
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||
raise InvalidRequestError(err_msg, param='input')
|
||||
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = temperature
|
||||
req_params['top_p'] = top_p
|
||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
|
||||
resp = {
|
||||
"object": "edit",
|
||||
"created": created_time,
|
||||
"choices": [{
|
||||
"text": answer,
|
||||
"index": 0,
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
embeddings_params_initialized = False
|
||||
# using 'lazy loading' to avoid circular import
|
||||
# so this function will be executed only once
|
||||
|
||||
|
||||
def initialize_embedding_params():
|
||||
'''
|
||||
using 'lazy loading' to avoid circular import
|
||||
so this function will be executed only once
|
||||
'''
|
||||
global embeddings_params_initialized
|
||||
if not embeddings_params_initialized:
|
||||
global st_model, embeddings_model, embeddings_device
|
||||
@ -26,7 +30,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
|
||||
initialize_embedding_params()
|
||||
global embeddings_device, embeddings_model
|
||||
try:
|
||||
print(f"\Try embedding model: {model} on {embeddings_device}")
|
||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
|
||||
embeddings_model = SentenceTransformer(model, device=embeddings_device)
|
||||
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
|
||||
@ -54,7 +58,7 @@ def get_embeddings(input: list) -> np.ndarray:
|
||||
model = get_embeddings_model()
|
||||
debug_msg(f"embedding model : {model}")
|
||||
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
|
||||
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
|
||||
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will
|
||||
return embedding
|
||||
|
||||
|
||||
|
@ -50,6 +50,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
'data': []
|
||||
}
|
||||
from extensions.openai.script import params
|
||||
|
||||
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img"
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
SpeechRecognition==3.10.0
|
||||
flask_cloudflared==0.0.12
|
||||
flask_cloudflared==0.0.14
|
||||
sentence-transformers
|
||||
sse-starlette==1.6.5
|
||||
tiktoken
|
||||
|
@ -1,351 +1,255 @@
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
import traceback
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
import extensions.openai.completions as OAIcompletions
|
||||
import extensions.openai.edits as OAIedits
|
||||
import extensions.openai.embeddings as OAIembeddings
|
||||
import extensions.openai.images as OAIimages
|
||||
import extensions.openai.models as OAImodels
|
||||
import extensions.openai.moderations as OAImoderations
|
||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||
from extensions.openai.errors import (
|
||||
InvalidRequestError,
|
||||
OpenAIError,
|
||||
ServiceUnavailableError
|
||||
)
|
||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules import shared
|
||||
|
||||
import cgi
|
||||
import speech_recognition as sr
|
||||
import uvicorn
|
||||
from extensions.openai.errors import ServiceUnavailableError
|
||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||
from extensions.openai.utils import _start_cloudflared
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from pydub import AudioSegment
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from .typing import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
to_dict
|
||||
)
|
||||
|
||||
params = {
|
||||
# default params
|
||||
'port': 5001,
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'all-mpnet-base-v2',
|
||||
|
||||
# optional params
|
||||
'sd_webui_url': '',
|
||||
'debug': 0
|
||||
}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_access_control_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
self.send_header("Access-Control-Allow-Credentials", "true")
|
||||
self.send_header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,OPTIONS,POST,PUT"
|
||||
)
|
||||
self.send_header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, Accept, X-Requested-With, Content-Type, "
|
||||
"Access-Control-Request-Method, Access-Control-Request-Headers, "
|
||||
"Authorization"
|
||||
)
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write("OK".encode('utf-8'))
|
||||
def verify_api_key(authorization: str = Header(None)) -> None:
|
||||
expected_api_key = shared.args.api_key
|
||||
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
def start_sse(self):
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'text/event-stream')
|
||||
self.send_header('Cache-Control', 'no-cache')
|
||||
# self.send_header('Connection', 'keep-alive')
|
||||
self.end_headers()
|
||||
|
||||
def send_sse(self, chunk: dict):
|
||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
app = FastAPI(dependencies=[Depends(verify_api_key)])
|
||||
|
||||
def end_sse(self):
|
||||
response = 'data: [DONE]\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
# Configure CORS settings to allow all origins, methods, and headers
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
|
||||
allow_headers=[
|
||||
"Origin",
|
||||
"Accept",
|
||||
"X-Requested-With",
|
||||
"Content-Type",
|
||||
"Access-Control-Request-Method",
|
||||
"Access-Control-Request-Headers",
|
||||
"Authorization",
|
||||
],
|
||||
)
|
||||
|
||||
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||
self.send_response(code)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
|
||||
response = json.dumps(ret)
|
||||
r_utf8 = response.encode('utf-8')
|
||||
@app.options("/")
|
||||
async def options_route():
|
||||
return JSONResponse(content="OK")
|
||||
|
||||
self.send_header('Content-Length', str(len(r_utf8)))
|
||||
self.end_headers()
|
||||
|
||||
self.wfile.write(r_utf8)
|
||||
if not no_debug:
|
||||
debug_msg(r_utf8)
|
||||
@app.post('/v1/completions', response_model=CompletionResponse)
|
||||
@app.post('/v1/generate', response_model=CompletionResponse)
|
||||
async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||
path = request.url.path
|
||||
is_legacy = "/generate" in path
|
||||
|
||||
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
||||
if request_data.stream:
|
||||
async def generator():
|
||||
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
for resp in response:
|
||||
yield {"data": json.dumps(resp)}
|
||||
|
||||
error_resp = {
|
||||
'error': {
|
||||
'message': message,
|
||||
'code': code,
|
||||
'type': error_type,
|
||||
'param': param,
|
||||
}
|
||||
}
|
||||
if internal_message:
|
||||
print(error_type, message)
|
||||
print(internal_message)
|
||||
# error_resp['internal_message'] = internal_message
|
||||
return EventSourceResponse(generator()) # SSE streaming
|
||||
|
||||
self.return_json(error_resp, code)
|
||||
else:
|
||||
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
return JSONResponse(response)
|
||||
|
||||
def openai_error_handler(func):
|
||||
def wrapper(self):
|
||||
try:
|
||||
func(self)
|
||||
except InvalidRequestError as e:
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
|
||||
except OpenAIError as e:
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
|
||||
except Exception as e:
|
||||
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||
|
||||
return wrapper
|
||||
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
|
||||
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
|
||||
path = request.url.path
|
||||
is_legacy = "/generate" in path
|
||||
|
||||
@openai_error_handler
|
||||
def do_GET(self):
|
||||
debug_msg(self.requestline)
|
||||
debug_msg(self.headers)
|
||||
if request_data.stream:
|
||||
async def generator():
|
||||
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
for resp in response:
|
||||
yield {"data": json.dumps(resp)}
|
||||
|
||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
||||
is_legacy = 'engines' in self.path
|
||||
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
||||
if is_legacy and not is_list:
|
||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||
resp = OAImodels.load_model(model_name)
|
||||
elif is_list:
|
||||
resp = OAImodels.list_models(is_legacy)
|
||||
else:
|
||||
model_name = self.path[len('/v1/models/'):]
|
||||
resp = OAImodels.model_info(model_name)
|
||||
return EventSourceResponse(generator()) # SSE streaming
|
||||
|
||||
self.return_json(resp)
|
||||
else:
|
||||
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
return JSONResponse(response)
|
||||
|
||||
elif '/billing/usage' in self.path:
|
||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||
self.return_json({"total_usage": 0}, no_debug=True)
|
||||
|
||||
else:
|
||||
self.send_error(404)
|
||||
@app.get("/v1/models")
|
||||
@app.get("/v1/engines")
|
||||
async def handle_models(request: Request):
|
||||
path = request.url.path
|
||||
is_legacy = 'engines' in path
|
||||
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
||||
|
||||
@openai_error_handler
|
||||
def do_POST(self):
|
||||
if is_legacy and not is_list:
|
||||
model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||
resp = OAImodels.load_model(model_name)
|
||||
elif is_list:
|
||||
resp = OAImodels.list_models(is_legacy)
|
||||
else:
|
||||
model_name = path[len('/v1/models/'):]
|
||||
resp = OAImodels.model_info(model_name)
|
||||
|
||||
if '/v1/audio/transcriptions' in self.path:
|
||||
r = sr.Recognizer()
|
||||
return JSONResponse(content=resp)
|
||||
|
||||
# Parse the form data
|
||||
form = cgi.FieldStorage(
|
||||
fp=self.rfile,
|
||||
headers=self.headers,
|
||||
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
|
||||
)
|
||||
|
||||
audio_file = form['file'].file
|
||||
audio_data = AudioSegment.from_file(audio_file)
|
||||
|
||||
# Convert AudioSegment to raw data
|
||||
raw_data = audio_data.raw_data
|
||||
|
||||
# Create AudioData object
|
||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||
whipser_language = form.getvalue('language', None)
|
||||
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||
|
||||
transcription = {"text": ""}
|
||||
|
||||
try:
|
||||
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||
|
||||
self.return_json(transcription, no_debug=True)
|
||||
return
|
||||
|
||||
debug_msg(self.requestline)
|
||||
debug_msg(self.headers)
|
||||
@app.get('/v1/billing/usage')
|
||||
def handle_billing_usage():
|
||||
'''
|
||||
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||
'''
|
||||
return JSONResponse(content={"total_usage": 0})
|
||||
|
||||
content_length = self.headers.get('Content-Length')
|
||||
transfer_encoding = self.headers.get('Transfer-Encoding')
|
||||
|
||||
if content_length:
|
||||
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8'))
|
||||
elif transfer_encoding == 'chunked':
|
||||
chunks = []
|
||||
while True:
|
||||
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
|
||||
if chunk_size == 0:
|
||||
break # End of chunks
|
||||
chunks.append(self.rfile.read(chunk_size))
|
||||
self.rfile.readline() # Consume the trailing newline after each chunk
|
||||
body = json.loads(b''.join(chunks).decode('utf-8'))
|
||||
else:
|
||||
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
|
||||
self.end_headers()
|
||||
return
|
||||
@app.post('/v1/audio/transcriptions')
|
||||
async def handle_audio_transcription(request: Request):
|
||||
r = sr.Recognizer()
|
||||
|
||||
debug_msg(body)
|
||||
form = await request.form()
|
||||
audio_file = await form["file"].read()
|
||||
audio_data = AudioSegment.from_file(audio_file)
|
||||
|
||||
if '/completions' in self.path or '/generate' in self.path:
|
||||
# Convert AudioSegment to raw data
|
||||
raw_data = audio_data.raw_data
|
||||
|
||||
if not shared.model:
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
# Create AudioData object
|
||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||
whipser_language = form.getvalue('language', None)
|
||||
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||
|
||||
is_legacy = '/generate' in self.path
|
||||
is_streaming = body.get('stream', False)
|
||||
transcription = {"text": ""}
|
||||
|
||||
if is_streaming:
|
||||
self.start_sse()
|
||||
try:
|
||||
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||
|
||||
response = []
|
||||
if 'chat' in self.path:
|
||||
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||
return JSONResponse(content=transcription)
|
||||
|
||||
for resp in response:
|
||||
self.send_sse(resp)
|
||||
|
||||
self.end_sse()
|
||||
@app.post('/v1/images/generations')
|
||||
async def handle_image_generation(request: Request):
|
||||
|
||||
else:
|
||||
response = ''
|
||||
if 'chat' in self.path:
|
||||
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.completions(body, is_legacy=is_legacy)
|
||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
self.return_json(response)
|
||||
body = await request.json()
|
||||
prompt = body['prompt']
|
||||
size = body.get('size', '1024x1024')
|
||||
response_format = body.get('response_format', 'url') # or b64_json
|
||||
n = body.get('n', 1) # ignore the batch limits of max 10
|
||||
|
||||
elif '/edits' in self.path:
|
||||
# deprecated
|
||||
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||
return JSONResponse(response)
|
||||
|
||||
if not shared.model:
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
req_params = get_default_req_params()
|
||||
@app.post("/v1/embeddings")
|
||||
async def handle_embeddings(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
|
||||
instruction = body['instruction']
|
||||
input = body.get('input', '')
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
self.return_json(response)
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
return JSONResponse(response)
|
||||
|
||||
elif '/images/generations' in self.path:
|
||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
prompt = body['prompt']
|
||||
size = default(body, 'size', '1024x1024')
|
||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||
n = default(body, 'n', 1) # ignore the batch limits of max 10
|
||||
@app.post("/v1/moderations")
|
||||
async def handle_moderations(request: Request):
|
||||
body = await request.json()
|
||||
input = body["input"]
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||
response = OAImoderations.moderations(input)
|
||||
return JSONResponse(response)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif '/embeddings' in self.path:
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
@app.post("/api/v1/token-count")
|
||||
async def handle_token_count(request: Request):
|
||||
body = await request.json()
|
||||
response = token_count(body['prompt'])
|
||||
return JSONResponse(response)
|
||||
|
||||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
@app.post("/api/v1/token/encode")
|
||||
async def handle_token_encode(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
response = token_encode(body["input"], encoding_format)
|
||||
return JSONResponse(response)
|
||||
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif '/moderations' in self.path:
|
||||
input = body['input']
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
response = OAImoderations.moderations(input)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||
response = token_count(body['prompt'])
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/encode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_encode(body['input'], encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/decode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_decode(body['input'], encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
else:
|
||||
self.send_error(404)
|
||||
@app.post("/api/v1/token/decode")
|
||||
async def handle_token_decode(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
response = token_decode(body["input"], encoding_format)
|
||||
return JSONResponse(response, no_debug=True)
|
||||
|
||||
|
||||
def run_server():
|
||||
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
|
||||
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
|
||||
if ssl_verify:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(port, port + 1)
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||
|
||||
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||
|
||||
if shared.args.public_api:
|
||||
def on_start(public_url: str):
|
||||
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
|
||||
|
||||
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
|
||||
else:
|
||||
if ssl_verify:
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
if ssl_keyfile and ssl_certfile:
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
|
||||
else:
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
|
||||
server.serve_forever()
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
|
||||
|
||||
if shared.args.api_key:
|
||||
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
|
||||
|
||||
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
||||
|
||||
|
||||
def setup():
|
||||
|
125
extensions/openai/typing.py
Normal file
125
extensions/openai/typing.py
Normal file
@ -0,0 +1,125 @@
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GenerationOptions(BaseModel):
|
||||
preset: str | None = None
|
||||
temperature: float = 1
|
||||
temperature_last: bool = False
|
||||
top_p: float = 1
|
||||
min_p: float = 0
|
||||
top_k: int = 0
|
||||
repetition_penalty: float = 1
|
||||
presence_penalty: float = 0
|
||||
frequency_penalty: float = 0
|
||||
repetition_penalty_range: int = 0
|
||||
typical_p: float = 1
|
||||
tfs: float = 1
|
||||
top_a: float = 0
|
||||
epsilon_cutoff: float = 0
|
||||
eta_cutoff: float = 0
|
||||
guidance_scale: float = 1
|
||||
negative_prompt: str = ''
|
||||
penalty_alpha: float = 0
|
||||
mirostat_mode: int = 0
|
||||
mirostat_tau: float = 5
|
||||
mirostat_eta: float = 0.1
|
||||
do_sample: bool = True
|
||||
seed: int = -1
|
||||
encoder_repetition_penalty: float = 1
|
||||
no_repeat_ngram_size: int = 0
|
||||
min_length: int = 0
|
||||
num_beams: int = 1
|
||||
length_penalty: float = 1
|
||||
early_stopping: bool = False
|
||||
truncation_length: int = 0
|
||||
max_tokens_second: int = 0
|
||||
custom_token_bans: str = ""
|
||||
auto_max_new_tokens: bool = False
|
||||
ban_eos_token: bool = False
|
||||
add_bos_token: bool = True
|
||||
skip_special_tokens: bool = True
|
||||
grammar_string: str = ""
|
||||
|
||||
|
||||
class CompletionRequest(GenerationOptions):
|
||||
model: str | None = None
|
||||
prompt: str | List[str]
|
||||
best_of: int | None = 1
|
||||
echo: bool | None = False
|
||||
frequency_penalty: float | None = 0
|
||||
logit_bias: dict | None = None
|
||||
logprobs: int | None = None
|
||||
max_tokens: int | None = 16
|
||||
n: int | None = 1
|
||||
presence_penalty: int | None = 0
|
||||
stop: str | List[str] | None = None
|
||||
stream: bool | None = False
|
||||
suffix: str | None = None
|
||||
temperature: float | None = 1
|
||||
top_p: float | None = 1
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str
|
||||
choices: List[dict]
|
||||
created: int = int(time.time())
|
||||
model: str
|
||||
object: str = "text_completion"
|
||||
usage: dict
|
||||
|
||||
|
||||
class ChatCompletionRequest(GenerationOptions):
|
||||
messages: List[dict]
|
||||
model: str | None = None
|
||||
frequency_penalty: float | None = 0
|
||||
function_call: str | dict | None = None
|
||||
functions: List[dict] | None = None
|
||||
logit_bias: dict | None = None
|
||||
max_tokens: int | None = None
|
||||
n: int | None = 1
|
||||
presence_penalty: int | None = 0
|
||||
stop: str | List[str] | None = None
|
||||
stream: bool | None = False
|
||||
temperature: float | None = 1
|
||||
top_p: float | None = 1
|
||||
user: str | None = None
|
||||
|
||||
mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.")
|
||||
|
||||
instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.")
|
||||
name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.")
|
||||
|
||||
character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. If not set, the default \"Assistant\" character will be used.")
|
||||
name1: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||
name2: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||
context: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||
greeting: str | None = Field(default=None, description="Overwrites the value set by character.")
|
||||
|
||||
chat_instruct_command: str | None = None
|
||||
|
||||
continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.")
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str
|
||||
choices: List[dict]
|
||||
created: int = int(time.time())
|
||||
model: str
|
||||
object: str = "chat.completion"
|
||||
usage: dict
|
||||
|
||||
|
||||
def to_json(obj):
|
||||
return json.dumps(obj.__dict__, indent=4)
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
return obj.__dict__
|
@ -1,8 +1,12 @@
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
# float_array = np.array(float_list, dtype="float32")
|
||||
@ -18,13 +22,33 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||
return ascii_string
|
||||
|
||||
|
||||
def end_line(s):
|
||||
if s and s[-1] != '\n':
|
||||
s = s + '\n'
|
||||
return s
|
||||
|
||||
|
||||
def debug_msg(*args, **kwargs):
|
||||
from extensions.openai.script import params
|
||||
if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)):
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
raise Exception(
|
||||
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
if tunnel_id is not None:
|
||||
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
||||
else:
|
||||
public_url = _run_cloudflared(port, port + 1)
|
||||
|
||||
if on_start:
|
||||
on_start(public_url)
|
||||
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
time.sleep(3)
|
||||
|
||||
raise Exception('Could not start cloudflared.')
|
||||
|
@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
# Find the maximum prompt size
|
||||
max_length = get_max_prompt_length(state)
|
||||
all_substrings = {
|
||||
'chat': get_turn_substrings(state, instruct=False),
|
||||
'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
|
||||
'instruct': get_turn_substrings(state, instruct=True)
|
||||
}
|
||||
|
||||
@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
||||
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||
|
||||
# Extract the reply
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = reply
|
||||
if state['mode'] in ['chat', 'chat-instruct']:
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
|
||||
visible_reply = html.escape(visible_reply)
|
||||
|
||||
if shared.stop_everything:
|
||||
|
@ -71,11 +71,12 @@ def load_model(model_name, loader=None):
|
||||
'AutoAWQ': AutoAWQ_loader,
|
||||
}
|
||||
|
||||
metadata = get_model_metadata(model_name)
|
||||
if loader is None:
|
||||
if shared.args.loader is not None:
|
||||
loader = shared.args.loader
|
||||
else:
|
||||
loader = get_model_metadata(model_name)['loader']
|
||||
loader = metadata['loader']
|
||||
if loader is None:
|
||||
logger.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
@ -95,6 +96,7 @@ def load_model(model_name, loader=None):
|
||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||
llama_attn_hijack.hijack_llama_attention()
|
||||
|
||||
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
|
||||
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
|
||||
|
@ -6,33 +6,32 @@ import yaml
|
||||
|
||||
def default_preset():
|
||||
return {
|
||||
'do_sample': True,
|
||||
'temperature': 1,
|
||||
'temperature_last': False,
|
||||
'top_p': 1,
|
||||
'min_p': 0,
|
||||
'top_k': 0,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1,
|
||||
'presence_penalty': 0,
|
||||
'frequency_penalty': 0,
|
||||
'repetition_penalty_range': 0,
|
||||
'typical_p': 1,
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
'guidance_scale': 1,
|
||||
'penalty_alpha': 0,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5,
|
||||
'mirostat_eta': 0.1,
|
||||
'do_sample': True,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'min_length': 0,
|
||||
'guidance_scale': 1,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0,
|
||||
'mirostat_eta': 0.1,
|
||||
'penalty_alpha': 0,
|
||||
'num_beams': 1,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'custom_token_bans': '',
|
||||
}
|
||||
|
||||
|
||||
|
@ -39,21 +39,21 @@ settings = {
|
||||
'max_new_tokens': 200,
|
||||
'max_new_tokens_min': 1,
|
||||
'max_new_tokens_max': 4096,
|
||||
'seed': -1,
|
||||
'negative_prompt': '',
|
||||
'seed': -1,
|
||||
'truncation_length': 2048,
|
||||
'truncation_length_min': 0,
|
||||
'truncation_length_max': 32768,
|
||||
'custom_stopping_strings': '',
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
'ban_eos_token': False,
|
||||
'custom_stopping_strings': '',
|
||||
'custom_token_bans': '',
|
||||
'auto_max_new_tokens': False,
|
||||
'ban_eos_token': False,
|
||||
'add_bos_token': True,
|
||||
'skip_special_tokens': True,
|
||||
'stream': True,
|
||||
'name1': 'You',
|
||||
'character': 'Assistant',
|
||||
'name1': 'You',
|
||||
'instruction_template': 'Alpaca',
|
||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||
'autoload_model': False,
|
||||
@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif
|
||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
|
||||
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
|
||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
||||
parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
|
||||
parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
|
||||
|
||||
# Multimodal
|
||||
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
|
||||
@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED')
|
||||
parser.add_argument('--chat', action='store_true', help='DEPRECATED')
|
||||
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
|
||||
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
|
||||
parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED')
|
||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED')
|
||||
|
||||
args = parser.parse_args()
|
||||
args_defaults = parser.parse_args([])
|
||||
@ -233,10 +235,13 @@ def fix_loader_name(name):
|
||||
return 'AutoAWQ'
|
||||
|
||||
|
||||
def add_extension(name):
|
||||
def add_extension(name, last=False):
|
||||
if args.extensions is None:
|
||||
args.extensions = [name]
|
||||
elif 'api' not in args.extensions:
|
||||
elif last:
|
||||
args.extensions = [x for x in args.extensions if x != name]
|
||||
args.extensions.append(name)
|
||||
elif name not in args.extensions:
|
||||
args.extensions.append(name)
|
||||
|
||||
|
||||
@ -246,14 +251,15 @@ def is_chat():
|
||||
|
||||
args.loader = fix_loader_name(args.loader)
|
||||
|
||||
# Activate the API extension
|
||||
if args.api or args.public_api:
|
||||
add_extension('api')
|
||||
|
||||
# Activate the multimodal extension
|
||||
if args.multimodal_pipeline is not None:
|
||||
add_extension('multimodal')
|
||||
|
||||
# Activate the API extension
|
||||
if args.api:
|
||||
# add_extension('openai', last=True)
|
||||
add_extension('api', last=True)
|
||||
|
||||
# Load model-specific settings
|
||||
with Path(f'{args.model_dir}/config.yaml') as p:
|
||||
if p.exists():
|
||||
|
@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
|
||||
|
||||
# Find the stopping strings
|
||||
all_stop_strings = []
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
for st in (stopping_strings, state['custom_stopping_strings']):
|
||||
if type(st) is str:
|
||||
st = ast.literal_eval(f"[{st}]")
|
||||
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
|
@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False):
|
||||
if 'instruction_template' in settings:
|
||||
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
|
||||
|
||||
# Applying the changes to the global shared settings (in-memory)
|
||||
shared.settings.update({k: v for k, v in settings.items() if k in shared.settings})
|
||||
|
||||
yield output
|
||||
else:
|
||||
yield f"Failed to load `{selected_model}`."
|
||||
|
Loading…
Reference in New Issue
Block a user