mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
This commit is contained in:
parent
ebfcfa41f2
commit
b45baeea41
@ -13,17 +13,56 @@ pip3 install -r requirements.txt
|
|||||||
|
|
||||||
It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this.
|
It listens on tcp port 5001 by default. You can use the OPENEDAI_PORT environment variable to change this.
|
||||||
|
|
||||||
To enable the bare bones image generation (txt2img) set: SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
Make sure you enable it in server launch parameters, it should include:
|
||||||
|
|
||||||
Example:
|
```
|
||||||
|
--extensions openai
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also use the ``--listen`` argument to make the server available on the networ, and/or the ```--share``` argument to enable a public Cloudflare endpoint.
|
||||||
|
|
||||||
|
To enable the basic image generation support (txt2img) set the environment variable SD_WEBUI_URL to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)).
|
||||||
|
|
||||||
|
For example:
|
||||||
```
|
```
|
||||||
SD_WEBUI_URL=http://127.0.0.1:7861
|
SD_WEBUI_URL=http://127.0.0.1:7861
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you enable it in server launch parameters. Just make sure they include:
|
### Models
|
||||||
|
|
||||||
|
This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat.
|
||||||
|
|
||||||
|
For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start.
|
||||||
|
|
||||||
|
For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama.
|
||||||
|
|
||||||
|
For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following - within the limits of the model. Be sure that the proper instruction template is detected and loaded or the results will not be good.
|
||||||
|
|
||||||
|
For the proper instruction format to be detected you need to have a matching model entry in your ```models/config.yaml``` file. Be sure to keep this file up to date.
|
||||||
|
A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results.
|
||||||
|
|
||||||
|
For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry:
|
||||||
|
|
||||||
```
|
```
|
||||||
--extensions openai
|
.*wizard.*vicuna:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Vicuna-v1.1'
|
||||||
|
```
|
||||||
|
|
||||||
|
This refers to ```characters/instruction-following/Vicuna-v1.1.yaml```, which looks like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
user: "USER:"
|
||||||
|
bot: "ASSISTANT:"
|
||||||
|
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"
|
||||||
|
context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n"
|
||||||
|
```
|
||||||
|
|
||||||
|
For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results.
|
||||||
|
|
||||||
|
If you see this in your logs, it probably means that the correct format could not be loaded:
|
||||||
|
```
|
||||||
|
Warning: Loaded default instruction-following template for model.
|
||||||
```
|
```
|
||||||
|
|
||||||
### Embeddings (alpha)
|
### Embeddings (alpha)
|
||||||
@ -43,12 +82,13 @@ Warning: You cannot mix embeddings from different models even if they have the s
|
|||||||
|
|
||||||
### Client Application Setup
|
### Client Application Setup
|
||||||
|
|
||||||
|
|
||||||
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
|
Almost everything you use it with will require you to set a dummy OpenAI API key environment variable.
|
||||||
|
|
||||||
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
|
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
|
||||||
|
|
||||||
```
|
```
|
||||||
OPENAI_API_KEY=sk-dummy
|
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||||
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -80,6 +120,29 @@ const api = new ChatGPTAPI({
|
|||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## API Documentation & Examples
|
||||||
|
|
||||||
|
The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference
|
||||||
|
|
||||||
|
Examples of how to use the Completions API in Python can be found here: https://platform.openai.com/examples
|
||||||
|
Not all of them will work with all models unfortunately, See the notes on Models for how to get the best results.
|
||||||
|
|
||||||
|
Here is a simple python example of how you can use the Edit endpoint as a translator.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
response = openai.Edit.create(
|
||||||
|
model="x",
|
||||||
|
instruction="Translate this into French",
|
||||||
|
input="Our mission is to ensure that artificial general intelligence benefits all of humanity.",
|
||||||
|
)
|
||||||
|
print(response['choices'][0]['text'])
|
||||||
|
# Sample Output:
|
||||||
|
# Notre mission est de garantir que l'intelligence artificielle généralisée profite à tous les membres de l'humanité.
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Compatibility & not so compatibility
|
## Compatibility & not so compatibility
|
||||||
|
|
||||||
| API endpoint | tested with | notes |
|
| API endpoint | tested with | notes |
|
||||||
@ -114,26 +177,32 @@ Some hacky mappings:
|
|||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
|
| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
|
||||||
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
|
| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
|
||||||
| best_of | top_k | |
|
| best_of | top_k | default is 1 |
|
||||||
| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. |
|
| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. |
|
||||||
| n | 1 | hardcoded, it may be worth implementing this but I'm not sure how yet |
|
| n | 1 | variations are not supported yet. |
|
||||||
| 1.0 | typical_p | hardcoded |
|
| 1 | num_beams | hardcoded to 1 |
|
||||||
| 1 | num_beams | hardcoded |
|
| 1.0 | typical_p | hardcoded to 1.0 |
|
||||||
| max_tokens | max_new_tokens | For Text Completions max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no input to be generated if the prompt is too large. For ChatCompletions, the older chat messages may be dropped to fit the max_new_tokens requested |
|
| max_tokens | max_new_tokens | For Text Completions max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no input to be generated if the prompt is too large. For ChatCompletions, the older chat messages may be dropped to fit the max_new_tokens requested |
|
||||||
| logprobs | - | ignored |
|
| logprobs | - | not supported yet |
|
||||||
| logit_bias | - | ignored |
|
| logit_bias | - | not supported yet |
|
||||||
| messages.name | - | ignored |
|
| messages.name | - | not supported yet |
|
||||||
| user | - | ignored |
|
| user | - | not supported yet |
|
||||||
|
| functions/function_call | - | function calls are not supported yet |
|
||||||
|
|
||||||
defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
|
defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
|
||||||
|
|
||||||
### Models
|
|
||||||
|
|
||||||
This has been successfully tested with Koala, Alpaca, gpt4-x-alpaca, GPT4all-snoozy, wizard-vicuna, stable-vicuna and Vicuna 1.1 - ie. Instruction Following models. If you test with other models please let me know how it goes. Less than satisfying results (so far): RWKV-4-Raven, llama, mpt-7b-instruct/chat
|
|
||||||
|
|
||||||
### Applications
|
### Applications
|
||||||
|
|
||||||
Everything needs OPENAI_API_KEY=dummy set.
|
Almost everything needs the OPENAI_API_KEY environment variable set, for example:
|
||||||
|
```
|
||||||
|
OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111
|
||||||
|
```
|
||||||
|
Some apps are picky about key format, but 'dummy' or 'sk-dummy' also work in most cases.
|
||||||
|
Most application will work if you also set:
|
||||||
|
```
|
||||||
|
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
||||||
|
```
|
||||||
|
but there are some exceptions.
|
||||||
|
|
||||||
| Compatibility | Application/Library | url | notes / setting |
|
| Compatibility | Application/Library | url | notes / setting |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
@ -144,7 +213,8 @@ Everything needs OPENAI_API_KEY=dummy set.
|
|||||||
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
|
| ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
|
||||||
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
| ✅ | OpenAI for Notepad++| https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file |
|
| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file, or environment variables |
|
||||||
|
| ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
||||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
|
@ -54,7 +54,7 @@ default_req_params = {
|
|||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
'ban_eos_token': False,
|
'ban_eos_token': False,
|
||||||
'skip_special_tokens': True,
|
'skip_special_tokens': True,
|
||||||
'custom_stopping_strings': ['\n###'],
|
'custom_stopping_strings': '',
|
||||||
}
|
}
|
||||||
|
|
||||||
# Optional, install the module and download the model to enable
|
# Optional, install the module and download the model to enable
|
||||||
@ -254,7 +254,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
is_legacy = '/generate' in self.path
|
is_legacy = '/generate' in self.path
|
||||||
is_chat = 'chat' in self.path
|
is_chat_request = 'chat' in self.path
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
# XXX model is ignored for now
|
# XXX model is ignored for now
|
||||||
@ -262,23 +262,23 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
model = shared.model_name
|
model = shared.model_name
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
|
|
||||||
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
|
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)
|
||||||
|
|
||||||
# Request Parameters
|
# Request Parameters
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
# Try to use openai defaults or map them to something with the same intent
|
||||||
req_params = default_req_params.copy()
|
req_params = default_req_params.copy()
|
||||||
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
|
stopping_strings = []
|
||||||
|
|
||||||
if 'stop' in body:
|
if 'stop' in body:
|
||||||
if isinstance(body['stop'], str):
|
if isinstance(body['stop'], str):
|
||||||
req_params['custom_stopping_strings'].extend([body['stop']])
|
stopping_strings.extend([body['stop']])
|
||||||
elif isinstance(body['stop'], list):
|
elif isinstance(body['stop'], list):
|
||||||
req_params['custom_stopping_strings'].extend(body['stop'])
|
stopping_strings.extend(body['stop'])
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
||||||
|
|
||||||
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it.
|
default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it.
|
||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||||
@ -295,9 +295,11 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
||||||
|
|
||||||
|
is_streaming = req_params['stream']
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
if req_params['stream']:
|
if is_streaming:
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
self.send_header('Content-Type', 'text/event-stream')
|
||||||
self.send_header('Cache-Control', 'no-cache')
|
self.send_header('Cache-Control', 'no-cache')
|
||||||
# self.send_header('Connection', 'keep-alive')
|
# self.send_header('Connection', 'keep-alive')
|
||||||
@ -311,7 +313,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
stream_object_type = ''
|
stream_object_type = ''
|
||||||
object_type = ''
|
object_type = ''
|
||||||
|
|
||||||
if is_chat:
|
if is_chat_request:
|
||||||
# Chat Completions
|
# Chat Completions
|
||||||
stream_object_type = 'chat.completions.chunk'
|
stream_object_type = 'chat.completions.chunk'
|
||||||
object_type = 'chat.completions'
|
object_type = 'chat.completions'
|
||||||
@ -347,20 +349,22 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'prompt': bot_prompt,
|
'prompt': bot_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
if instruct['user']: # WizardLM and some others have no user prompt.
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
stopping_strings.extend(['\n###'])
|
||||||
|
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||||
|
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
req_params['custom_stopping_strings'].extend(['\nuser:'])
|
stopping_strings.extend(['\nuser:'])
|
||||||
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
req_params['custom_stopping_strings'].extend(['\nuser:'])
|
stopping_strings.extend(['\nuser:'])
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
system_msgs = []
|
system_msgs = []
|
||||||
@ -391,7 +395,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
system_msg = system_msg + '\n'
|
system_msg = system_msg + '\n'
|
||||||
|
|
||||||
system_token_count = len(encode(system_msg)[0])
|
system_token_count = len(encode(system_msg)[0])
|
||||||
remaining_tokens = req_params['truncation_length'] - system_token_count
|
remaining_tokens = truncation_length - system_token_count
|
||||||
chat_msg = ''
|
chat_msg = ''
|
||||||
|
|
||||||
while chat_msgs:
|
while chat_msgs:
|
||||||
@ -424,20 +428,19 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
if token_count >= req_params['truncation_length']:
|
if token_count >= truncation_length:
|
||||||
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
|
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
|
||||||
prompt = prompt[-new_len:]
|
prompt = prompt[-new_len:]
|
||||||
new_token_count = len(encode(prompt)[0])
|
new_token_count = len(encode(prompt)[0])
|
||||||
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
|
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
|
||||||
token_count = new_token_count
|
token_count = new_token_count
|
||||||
|
|
||||||
if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
|
if truncation_length - token_count < req_params['max_new_tokens']:
|
||||||
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
|
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
req_params['max_new_tokens'] = truncation_length - token_count
|
||||||
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
||||||
|
|
||||||
if req_params['stream']:
|
if is_streaming:
|
||||||
shared.args.chat = True
|
|
||||||
# begin streaming
|
# begin streaming
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
@ -463,11 +466,11 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
if debug:
|
if debug:
|
||||||
print({'prompt': prompt, 'req_params': req_params})
|
print({'prompt': prompt, 'req_params': req_params})
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
|
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a
|
||||||
@ -476,7 +479,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
len_seen = len(seen_content)
|
len_seen = len(seen_content)
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
search_start = max(len_seen - longest_stop_len, 0)
|
||||||
|
|
||||||
for string in req_params['custom_stopping_strings']:
|
for string in stopping_strings:
|
||||||
idx = answer.find(string, search_start)
|
idx = answer.find(string, search_start)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
answer = answer[:idx] # clip it.
|
answer = answer[:idx] # clip it.
|
||||||
@ -489,7 +492,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# is completed, buffer and generate more, don't send it
|
# is completed, buffer and generate more, don't send it
|
||||||
buffer_and_continue = False
|
buffer_and_continue = False
|
||||||
|
|
||||||
for string in req_params['custom_stopping_strings']:
|
for string in stopping_strings:
|
||||||
for j in range(len(string) - 1, 0, -1):
|
for j in range(len(string) - 1, 0, -1):
|
||||||
if answer[-j:] == string[:j]:
|
if answer[-j:] == string[:j]:
|
||||||
buffer_and_continue = True
|
buffer_and_continue = True
|
||||||
@ -501,7 +504,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if buffer_and_continue:
|
if buffer_and_continue:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if req_params['stream']:
|
if is_streaming:
|
||||||
# Streaming
|
# Streaming
|
||||||
new_content = answer[len_seen:]
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
@ -534,7 +537,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
completion_token_count += len(encode(new_content)[0])
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
|
||||||
if req_params['stream']:
|
if is_streaming:
|
||||||
chunk = {
|
chunk = {
|
||||||
"id": cmpl_id,
|
"id": cmpl_id,
|
||||||
"object": stream_object_type,
|
"object": stream_object_type,
|
||||||
@ -575,7 +578,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
completion_token_count = len(encode(answer)[0])
|
||||||
stop_reason = "stop"
|
stop_reason = "stop"
|
||||||
if token_count + completion_token_count >= req_params['truncation_length']:
|
if token_count + completion_token_count >= truncation_length:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
resp = {
|
resp = {
|
||||||
@ -594,7 +597,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_chat:
|
if is_chat_request:
|
||||||
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
|
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
|
||||||
else:
|
else:
|
||||||
resp[resp_list][0]["text"] = answer
|
resp[resp_list][0]["text"] = answer
|
||||||
@ -620,7 +623,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
# Request parameters
|
# Request parameters
|
||||||
req_params = default_req_params.copy()
|
req_params = default_req_params.copy()
|
||||||
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
|
stopping_strings = []
|
||||||
|
|
||||||
# Alpaca is verbose so a good default prompt
|
# Alpaca is verbose so a good default prompt
|
||||||
default_template = (
|
default_template = (
|
||||||
@ -632,26 +635,29 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
instruction_template = default_template
|
instruction_template = default_template
|
||||||
|
|
||||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||||
if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
|
if shared.settings['instruction_template']:
|
||||||
try:
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
stopping_strings.extend(['\n###'])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
template = instruct['turn_template']
|
template = instruct['turn_template']
|
||||||
template = template\
|
template = template\
|
||||||
.replace('<|user|>', instruct.get('user', ''))\
|
.replace('<|user|>', instruct.get('user', ''))\
|
||||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||||
|
|
||||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||||
if instruct['user']:
|
if instruct['user']:
|
||||||
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
|
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
instruction_template = default_template
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
instruction_template = default_template
|
||||||
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
else:
|
else:
|
||||||
|
stopping_strings.extend(['\n###'])
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
|
|
||||||
|
|
||||||
@ -671,9 +677,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if debug:
|
if debug:
|
||||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
|
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
|
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
for a in generator:
|
for a in generator:
|
||||||
@ -683,7 +689,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
len_seen = len(seen_content)
|
len_seen = len(seen_content)
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
search_start = max(len_seen - longest_stop_len, 0)
|
||||||
|
|
||||||
for string in req_params['custom_stopping_strings']:
|
for string in stopping_strings:
|
||||||
idx = answer.find(string, search_start)
|
idx = answer.find(string, search_start)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
answer = answer[:idx] # clip it.
|
answer = answer[:idx] # clip it.
|
||||||
|
Loading…
Reference in New Issue
Block a user