Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-25 09:19:23 +01:00)

extensions/openai: docs update, model loader, minor fixes (#2557)

parent 2220b78e7a
commit 1e97aaac95
@@ -84,20 +84,25 @@ const api = new ChatGPTAPI({
 | API endpoint | tested with | notes |
 | --- | --- | --- |
-| /v1/models | openai.Model.list() | returns the currently loaded model_name and some mock compatibility options |
+| /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
 | /v1/models/{id} | openai.Model.get() | returns whatever you ask for, model does nothing yet anyways |
-| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far |
+| /v1/text_completion | openai.Completion.create() | the most tested, only supports single string input so far, variable quality based on the model |
-| /v1/chat/completions | openai.ChatCompletion.create() | depending on the model, this may add leading linefeeds |
+| /v1/chat/completions | openai.ChatCompletion.create() | Quality depends a lot on the model |
-| /v1/edits | openai.Edit.create() | Assumes an instruction following model, but may work with others |
+| /v1/edits | openai.Edit.create() | Works the best of all, perfect for instruction following models |
 | /v1/images/generations | openai.Image.create() | Bare bones, no model configuration, response_format='b64_json' only. |
 | /v1/embeddings | openai.Embedding.create() | Using Sentence Transformer, dimensions are different and may never be directly comparable to openai embeddings. |
 | /v1/moderations | openai.Moderation.create() | does nothing. successfully. |
-| /v1/engines/\*/... completions, embeddings, generate | python-openai v0.25 and earlier | Legacy engines endpoints |
-| /v1/images/edits | openai.Image.create_edit() | not supported |
-| /v1/images/variations | openai.Image.create_variation() | not supported |
-| /v1/audio/\* | openai.Audio.\* | not supported |
-| /v1/files\* | openai.Files.\* | not supported |
-| /v1/fine-tunes\* | openai.FineTune.\* | not supported |
+| /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
+| /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
+| /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
+| /v1/engines | openai engines.list | Legacy Lists models |
+| /v1/engines/{model_name} | openai engines.get -i {model_name} | You can use this legacy endpoint to load models via the api |
+| /v1/images/edits | openai.Image.create_edit() | not yet supported |
+| /v1/images/variations | openai.Image.create_variation() | not yet supported |
+| /v1/audio/\* | openai.Audio.\* | not yet supported |
+| /v1/files\* | openai.Files.\* | not yet supported |
+| /v1/fine-tunes\* | openai.FineTune.\* | not yet supported |
+| /v1/search | openai.search, engines.search | not yet supported |

 The model name setting is ignored in completions, but you may need to adjust the maximum token length to fit the model (ie. set to <2048 tokens instead of 4096, 8k, etc). To mitigate some of this, the max_tokens value is halved until it is less than truncation_length for the model (typically 2k).
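As a quick orientation, here is a minimal usage sketch against the table above. It assumes the extension is serving on http://127.0.0.1:5001 with OPENAI_API_KEY=dummy, and a pre-1.0 openai-python client (the openai.Completion / openai.ChatCompletion style shown in the table); adjust the address and max_tokens to your setup.

```python
import openai

openai.api_key = "dummy"                      # any non-empty value; the server ignores it
openai.api_base = "http://127.0.0.1:5001/v1"  # same effect as OPENAI_API_BASE

# Text completion: the model name is ignored, but keep max_tokens within the
# loaded model's context (e.g. well under 2048 for a typical 2k truncation_length).
completion = openai.Completion.create(
    model="text-curie-001",
    prompt="Write one sentence about local inference.",
    max_tokens=200,
)
print(completion["choices"][0]["text"])

# Chat completion against the same server.
chat = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(chat["choices"][0]["message"]["content"])
```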
@@ -110,12 +115,15 @@ Some hacky mappings:
 | frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way |
 | presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 |
 | best_of | top_k | |
-| stop | custom_stopping_strings | this is also stuffed with ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ] for good measure. |
+| stop | custom_stopping_strings | this is also stuffed with ['\n###', "\n{user prompt}", "{user prompt}" ] for good measure. |
 | n | 1 | hardcoded, it may be worth implementing this but I'm not sure how yet |
 | 1.0 | typical_p | hardcoded |
 | 1 | num_beams | hardcoded |
-| max_tokens | max_new_tokens | max_tokens is scaled down by powers of 2 until it's smaller than truncation length. |
+| max_tokens | max_new_tokens | For Text Completions max_tokens is set smaller than the truncation_length minus the prompt length. This can cause no input to be generated if the prompt is too large. For ChatCompletions, the older chat messages may be dropped to fit the max_new_tokens requested |
 | logprobs | - | ignored |
+| logit_bias | - | ignored |
+| messages.name | - | ignored |
+| user | - | ignored |

 defaults are mostly from openai, so are different. I use the openai defaults where I can and try to scale them to the webui defaults with the same intent.
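Purely as an illustration of the table above (this is not the extension's actual code, and the helper name is made up), the remapping amounts to something like the following sketch:

```python
# Hypothetical helper sketching the request-parameter mapping described above.
def map_openai_params(body: dict, truncation_length: int = 2048) -> dict:
    stop = body.get('stop') or []
    if isinstance(stop, str):  # the OpenAI API also allows a single string here
        stop = [stop]
    return {
        # frequency_penalty / presence_penalty use a different scale in the webui,
        # so they are currently pinned rather than translated.
        'encoder_repetition_penalty': 1.18,
        'repetition_penalty': 1.0,
        'top_k': body.get('best_of', 1),
        # OpenAI 'stop' strings are folded into custom_stopping_strings.
        'custom_stopping_strings': stop + ['\n###'],
        'typical_p': 1.0,   # hardcoded
        'num_beams': 1,     # hardcoded; 'n' is likewise fixed to 1
        # Simplified: keep max_new_tokens below the model's truncation_length.
        'max_new_tokens': min(body.get('max_tokens', 16), truncation_length - 1),
    }


print(map_openai_params({'max_tokens': 4096, 'stop': ['\nuser:'], 'best_of': 1}))
```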
@@ -129,13 +137,14 @@ Everything needs OPENAI_API_KEY=dummy set.
 | Compatibility | Application/Library | url | notes / setting |
 | --- | --- | --- | --- |
-| ✅❌ | openai-python | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
+| ✅❌ | openai-python (v0.25+) | https://github.com/openai/openai-python | only the endpoints from above are working. OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
 | ✅❌ | openai-node | https://github.com/openai/openai-node | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
 | ✅❌ | chatgpt-api | https://github.com/transitive-bullshit/chatgpt-api | only the endpoints from above are working. environment variables don't work by default, but can be configured (see above) |
 | ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI |
 | ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5001 |
 | ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
 | ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
+| ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5001 in the config file |
 | ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
 | ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
 | ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
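Most of the clients above only need the environment variables from their notes column; a minimal sketch (local address assumed) of setting them from Python before the client library is imported:

```python
import os

# Point OPENAI_API_BASE-aware clients at the local extension (address is an assumption).
os.environ["OPENAI_API_KEY"] = "dummy"                      # required by most clients, value unused
os.environ["OPENAI_API_BASE"] = "http://127.0.0.1:5001/v1"  # openai-python, gpt-shell, langchain, ...
os.environ["OPENAI_API_HOST"] = "http://127.0.0.1:5001"     # shell_gpt uses the host form instead

import openai  # imported after the variables are set so it picks them up

print([m["id"] for m in openai.Model.list()["data"]])
```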
@@ -148,4 +157,6 @@ Everything needs OPENAI_API_KEY=dummy set.

 ## Bugs? Feedback? Comments? Pull requests?

+To enable debugging and get copious output you can set the OPENEDAI_DEBUG=1 environment variable.
+
 Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible.
@@ -4,11 +4,13 @@ import os
 import time
 import requests
 import yaml
+import numpy as np

 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 from modules.utils import get_available_models
-import numpy as np
+from modules.models import load_model, unload_model
+from modules.models_settings import (get_model_settings_from_yamls,
+                                      update_model_parameters)

 from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -37,8 +39,8 @@ default_req_params = {
     'add_bos_token': True,
     'do_sample': True,
     'typical_p': 1.0,
-    'epsilon_cutoff': 0, # In units of 1e-4
-    'eta_cutoff': 0, # In units of 1e-4
+    'epsilon_cutoff': 0.0, # In units of 1e-4
+    'eta_cutoff': 0.0, # In units of 1e-4
     'tfs': 1.0,
     'top_a': 0.0,
     'min_length': 0,
@@ -142,41 +144,83 @@ class Handler(BaseHTTPRequestHandler):
         self.wfile.write("OK".encode('utf-8'))

     def do_GET(self):
-        if self.path.startswith('/v1/models'):
-            self.send_response(200)
-            self.send_access_control_headers()
-            self.send_header('Content-Type', 'application/json')
-            self.end_headers()
-
-            # TODO: Lora's?
-            # This API should list capabilities, limits and pricing...
-            current_model_list = [ shared.model_name ] # The real chat/completions model
+        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
+            current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None"
             embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
             pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
                 'gpt-3.5-turbo', # /v1/chat/completions
                 'text-curie-001', # /v1/completions, 2k context
                 'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
             ]
-            available_model_list = get_available_models()
-            all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
-
-            models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
-
-            response = ''
-            if self.path == '/v1/models':
-                response = json.dumps({
-                    "object": "list",
-                    "data": models,
-                })
+
+            is_legacy = 'engines' in self.path
+            is_list = self.path in ['/v1/engines', '/v1/models']
+
+            resp = ''
+
+            if is_legacy and not is_list: # load model
+                model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
+
+                resp = {
+                    "id": model_name,
+                    "object": "engine",
+                    "owner": "self",
+                    "ready": True,
+                }
+                if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only
+                    # No args. Maybe it works anyways!
+                    # TODO: hack some heuristics into args for better results
+
+                    shared.model_name = model_name
+                    unload_model()
+
+                    model_settings = get_model_settings_from_yamls(shared.model_name)
+                    shared.settings.update(model_settings)
+                    update_model_parameters(model_settings, initial=True)
+
+                    if shared.settings['mode'] != 'instruct':
+                        shared.settings['instruction_template'] = None
+
+                    shared.model, shared.tokenizer = load_model(shared.model_name)
+
+                    if not shared.model: # load failed.
+                        shared.model_name = "None"
+                        resp['id'] = "None"
+                        resp['ready'] = False
+
+            elif is_list:
+                # TODO: Lora's?
+                available_model_list = get_available_models()
+                all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
+
+                models = {}
+
+                if is_legacy:
+                    models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
+                    if not shared.model:
+                        models[0]['ready'] = False
+                else:
+                    models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
+
+                resp = {
+                    "object": "list",
+                    "data": models,
+                }
+
             else:
                 the_model_name = self.path[len('/v1/models/'):]
-                response = json.dumps({
+                resp = {
                     "id": the_model_name,
                     "object": "model",
                     "owned_by": "user",
                     "permission": []
-                })
+                }

+            self.send_response(200)
+            self.send_access_control_headers()
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+            response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))

         elif '/billing/usage' in self.path:
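A hedged usage sketch of the new loader path: it assumes the server at http://127.0.0.1:5001 and a pre-1.0 openai-python (e.g. v0.25) that still ships the legacy Engine API; the model name below is only a placeholder.

```python
import openai

openai.api_key = "dummy"
openai.api_base = "http://127.0.0.1:5001/v1"

# GET /v1/engines is the legacy listing; the first entry is the currently loaded model.
for engine in openai.Engine.list()["data"]:
    print(engine["id"], engine.get("ready"))

# GET /v1/engines/{model_name} asks the extension to load that model;
# "ready" comes back False if the load failed. The name here is a placeholder.
engine = openai.Engine.retrieve(id="some-local-model-name")
print(engine["id"], engine["ready"])
```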
@@ -283,35 +327,41 @@ class Handler(BaseHTTPRequestHandler):
             }

             # Instruct models can be much better
-            try:
-                instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
+            if shared.settings['instruction_template']:
+                try:
+                    instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))

                     template = instruct['turn_template']
                     system_message_template = "{message}"
                     system_message_default = instruct['context']
                     bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
                     user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
                     bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
                     bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

                     role_formats = {
                         'user': user_message_template,
                         'assistant': bot_message_template,
                         'system': system_message_template,
                         'context': system_message_default,
                         'prompt': bot_prompt,
                     }

                     if instruct['user']: # WizardLM and some others have no user prompt.
                         req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])

                     if debug:
                         print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
-            except:
-                req_params['custom_stopping_strings'].extend(['\nuser:'])
-
-                if debug:
-                    print("Loaded default role format.")
+
+                except Exception as e:
+                    req_params['custom_stopping_strings'].extend(['\nuser:'])
+
+                    print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
+                    print("Warning: Loaded default instruction-following template for model.")
+
+            else:
+                req_params['custom_stopping_strings'].extend(['\nuser:'])
+                print("Warning: Loaded default instruction-following template for model.")

             system_msgs = []
             chat_msgs = []
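For reference, a self-contained sketch of what the template parsing above produces, using made-up values in place of a real characters/instruction-following/*.yaml file:

```python
# Illustrative values only; a real template comes from
# characters/instruction-following/<name>.yaml via yaml.safe_load().
instruct = {
    'user': 'USER:',
    'bot': 'ASSISTANT:',
    'context': 'A chat between a user and an assistant.\n',
    'turn_template': '<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n',
}

template = instruct['turn_template']
bot_start = template.find('<|bot|>')  # every template is expected to contain this token

user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

print(repr(user_message_template))  # 'USER: {message}\n'
print(repr(bot_message_template))   # 'ASSISTANT: {message}\n'
print(repr(bot_prompt))             # 'ASSISTANT:'
```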
@@ -370,7 +420,8 @@ class Handler(BaseHTTPRequestHandler):
             prompt = body['prompt'] # XXX this can be different types

             if isinstance(prompt, list):
-                prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
+                self.openai_error("API Batched generation not yet supported.")
+                return

             token_count = len(encode(prompt)[0])
             if token_count >= req_params['truncation_length']:
@@ -412,7 +463,7 @@ class Handler(BaseHTTPRequestHandler):
             # generate reply #######################################
             if debug:
                 print({'prompt': prompt, 'req_params': req_params})
-            generator = generate_reply(prompt, req_params, is_chat=False)
+            generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)

             answer = ''
             seen_content = ''
@@ -569,6 +620,7 @@ class Handler(BaseHTTPRequestHandler):

             # Request parameters
             req_params = default_req_params.copy()
+            req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()

             # Alpaca is verbose so a good default prompt
             default_template = (
@@ -578,10 +630,9 @@ class Handler(BaseHTTPRequestHandler):
             )

             instruction_template = default_template
-            req_params['custom_stopping_strings'] = [ '\n###' ]

             # Use the special instruction/input/response template for anything trained like Alpaca
-            if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
+            if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
                 try:
                     instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))

@@ -593,9 +644,16 @@ class Handler(BaseHTTPRequestHandler):

                     instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
                     if instruct['user']:
-                        req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
-                except:
-                    pass
+                        req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
+
+                except Exception as e:
+                    instruction_template = default_template
+                    print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
+                    print("Warning: Loaded default instruction-following template (Alpaca) for model.")
+
+            else:
+                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
+
             edit_task = instruction_template.format(instruction=instruction, input=input)

@@ -613,12 +671,28 @@ class Handler(BaseHTTPRequestHandler):
             if debug:
                 print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

-            generator = generate_reply(edit_task, req_params, is_chat=False)
+            generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)

+            longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
             answer = ''
+            seen_content = ''
             for a in generator:
                 answer = a
+
+                stop_string_found = False
+                len_seen = len(seen_content)
+                search_start = max(len_seen - longest_stop_len, 0)
+
+                for string in req_params['custom_stopping_strings']:
+                    idx = answer.find(string, search_start)
+                    if idx != -1:
+                        answer = answer[:idx] # clip it.
+                        stop_string_found = True
+
+                if stop_string_found:
+                    break
+

             # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
             if edit_task[-1] != '\n' and answer and answer[0] == ' ':
                 answer = answer[1:]
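Finally, a small end-to-end sketch of the /v1/edits path this hunk touches, under the same assumptions as the earlier examples (local server at http://127.0.0.1:5001, dummy key, pre-1.0 openai-python); the extension answers with whichever model is currently loaded.

```python
import openai

openai.api_key = "dummy"
openai.api_base = "http://127.0.0.1:5001/v1"

# POST /v1/edits, served by the instruction-template code above.
result = openai.Edit.create(
    model="text-davinci-edit-001",  # field is accepted; the loaded webui model does the work
    instruction="Fix the spelling mistakes.",
    input="The quik brown fox jumpt over the lazy dog.",
)
print(result["choices"][0]["text"])
```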