mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-30 14:10:14 +01:00

Merge remote-tracking branch 'refs/remotes/origin/dev' into dev

This commit is contained in: commit 177ab7912a
@@ -20,6 +20,12 @@ Example:
 SD_WEBUI_URL=http://127.0.0.1:7861
 ```
 
+Make sure you enable it in server launch parameters. Just make sure they include:
+
+```
+--extensions openai
+```
+
 ### Embeddings (alpha)
 
 Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
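For illustration only (not part of the diff): a minimal sketch of querying the extension's embeddings endpoint with the official python client, assuming the server runs on port 5001 with a dummy key as described further down in this README.

```
import os

# Must be set before `import openai`.
os.environ['OPENAI_API_KEY'] = 'sk-dummy'
os.environ['OPENAI_API_BASE'] = 'http://127.0.0.1:5001/v1'

import openai

resp = openai.Embedding.create(
    model='text-embedding-ada-002',  # name is nominal; the extension uses its loaded sentence-transformers model
    input='Hello, world!',
)
vector = resp['data'][0]['embedding']
print(len(vector))  # 768 with all-mpnet-base-v2, per the paragraph above
```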
@@ -42,7 +48,7 @@ Almost everything you use it with will require you to set a dummy OpenAI API key
 With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
 
 ```
-OPENAI_API_KEY=dummy
+OPENAI_API_KEY=sk-dummy
 OPENAI_API_BASE=http://127.0.0.1:5001/v1
 ```
 
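A runnable sketch of the same configuration done from python rather than the shell, so the variables are in place before the import (port and key mirror the block above; the model name is nominal, the webui serves whatever model is loaded):

```
import os

# Set before `import openai`, exactly as the paragraph above requires.
os.environ['OPENAI_API_KEY'] = 'sk-dummy'
os.environ['OPENAI_API_BASE'] = 'http://127.0.0.1:5001/v1'

import openai

reply = openai.ChatCompletion.create(
    model='gpt-3.5-turbo',  # nominal; requests are served by the locally loaded model
    messages=[{'role': 'user', 'content': 'Say hello.'}],
)
print(reply['choices'][0]['message']['content'])
```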

@@ -20,6 +20,7 @@ params = {
 debug = True if 'OPENEDAI_DEBUG' in os.environ else False
 
 # Slightly different defaults for OpenAI's API
+# Data type is important, Ex. use 0.0 for a float 0
 default_req_params = {
     'max_new_tokens': 200,
     'temperature': 1.0,
@@ -44,14 +45,14 @@ default_req_params = {
     'no_repeat_ngram_size': 0,
     'num_beams': 1,
     'penalty_alpha': 0.0,
-    'length_penalty': 1,
+    'length_penalty': 1.0,
     'early_stopping': False,
     'mirostat_mode': 0,
-    'mirostat_tau': 5,
+    'mirostat_tau': 5.0,
     'mirostat_eta': 0.1,
     'ban_eos_token': False,
     'skip_special_tokens': True,
-    'custom_stopping_strings': [],
+    'custom_stopping_strings': ['\n###'],
 }
 
 # Optional, install the module and download the model to enable
@@ -64,8 +65,6 @@ except ImportError:
 st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
 embedding_model = None
 
-standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
-
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
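The `default()` helper is only partially visible in this hunk. As a rough sketch of the behavior its comment describes (an illustration, not the file's actual body):

```
def default(dic, key, default):
    # Return dic[key] unless the key is absent, the stored value is None,
    # or its type differs from the default's; fall back to the default then.
    val = dic.get(key, default)
    if val is None or not isinstance(val, type(default)):
        return default
    return val

# default({'temperature': None}, 'temperature', 1.0)  -> 1.0
# default({'temperature': 0.7}, 'temperature', 1.0)   -> 0.7
```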
@@ -86,31 +85,6 @@ def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
 
 
-def deduce_template():
-    # Alpaca is verbose so a good default prompt
-    default_template = (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-    )
-
-    # Use the special instruction/input/response template for anything trained like Alpaca
-    if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
-        return default_template
-
-    try:
-        instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
-
-        template = instruct['turn_template']
-        template = template\
-            .replace('<|user|>', instruct.get('user', ''))\
-            .replace('<|bot|>', instruct.get('bot', ''))\
-            .replace('<|user-message|>', '{instruction}\n{input}')
-        return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
-    except:
-        return default_template
-
-
 def float_list_to_base64(float_list):
     # Convert the list to a float32 array that the OpenAPI client expects
     float_array = np.array(float_list, dtype="float32")
@@ -141,6 +115,25 @@ class Handler(BaseHTTPRequestHandler):
             "Authorization"
         )
 
+    def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
+        self.send_response(code)
+        self.send_access_control_headers()
+        self.send_header('Content-Type', 'application/json')
+        self.end_headers()
+        error_resp = {
+            'error': {
+                'message': message,
+                'code': code,
+                'type': error_type,
+                'param': param,
+            }
+        }
+        if internal_message:
+            error_resp['internal_message'] = internal_message
+
+        response = json.dumps(error_resp)
+        self.wfile.write(response.encode('utf-8'))
+
     def do_OPTIONS(self):
         self.send_response(200)
         self.send_access_control_headers()
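The new openai_error() helper mirrors the error envelope the OpenAI API uses. For reference, the JSON body a client would receive from the "No model loaded." guard added later in do_POST (values taken from the default arguments above):

```
# Shape of the response body written by openai_error("No model loaded."):
error_body = {
    "error": {
        "message": "No model loaded.",
        "code": 500,          # default `code` argument
        "type": "APIError",   # default `error_type` argument
        "param": "",
    }
}
```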
@@ -150,42 +143,24 @@ class Handler(BaseHTTPRequestHandler):
 
     def do_GET(self):
         if self.path.startswith('/v1/models'):
 
             self.send_response(200)
             self.send_access_control_headers()
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
 
-            # TODO: list all models and allow model changes via API? Lora's?
+            # TODO: Lora's?
             # This API should list capabilities, limits and pricing...
-            models = [{
-                "id": shared.model_name, # The real chat/completions model
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": st_model, # The real sentence transformer embeddings model
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, { # these are expected by so much, so include some here as a dummy
-                "id": "gpt-3.5-turbo", # /v1/chat/completions
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": "text-curie-001", # /v1/completions, 2k context
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }, {
-                "id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
-                "object": "model",
-                "owned_by": "user",
-                "permission": []
-            }]
-
-            models.extend([{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in get_available_models() ])
+            current_model_list = [ shared.model_name ] # The real chat/completions model
+            embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
+            pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
+                'gpt-3.5-turbo', # /v1/chat/completions
+                'text-curie-001', # /v1/completions, 2k context
+                'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
+            ]
+            available_model_list = get_available_models()
+            all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
+
+            models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
 
             response = ''
             if self.path == '/v1/models':
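With the /v1/models handler rewritten as above, the listing is now the concatenation of the loaded model, the embeddings model (when loaded), a few dummy OpenAI names, and everything get_available_models() returns, all in one uniform shape. A client-side sketch (environment assumed configured as in the extension README):

```
import openai  # OPENAI_API_KEY / OPENAI_API_BASE assumed already set

for m in openai.Model.list()['data']:
    # every entry has the shape:
    # {"id": ..., "object": "model", "owned_by": "user", "permission": []}
    print(m['id'])
```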
@@ -203,6 +178,7 @@ class Handler(BaseHTTPRequestHandler):
             })
 
             self.wfile.write(response.encode('utf-8'))
+
         elif '/billing/usage' in self.path:
             # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
             self.send_response(200)
@@ -214,6 +190,7 @@ class Handler(BaseHTTPRequestHandler):
                 "total_usage": 0,
             })
             self.wfile.write(response.encode('utf-8'))
+
         else:
             self.send_error(404)
 
@@ -227,6 +204,11 @@ class Handler(BaseHTTPRequestHandler):
             print(body)
 
         if '/completions' in self.path or '/generate' in self.path:
+
+            if not shared.model:
+                self.openai_error("No model loaded.")
+                return
+
             is_legacy = '/generate' in self.path
             is_chat = 'chat' in self.path
             resp_list = 'data' if is_legacy else 'choices'
@@ -238,13 +220,16 @@ class Handler(BaseHTTPRequestHandler):
 
             cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
 
+            # Request Parameters
             # Try to use openai defaults or map them to something with the same intent
-            stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
+            req_params = default_req_params.copy()
+            req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
+
             if 'stop' in body:
                 if isinstance(body['stop'], str):
-                    stopping_strings = [body['stop']]
+                    req_params['custom_stopping_strings'].extend([body['stop']])
                 elif isinstance(body['stop'], list):
-                    stopping_strings = body['stop']
+                    req_params['custom_stopping_strings'].extend(body['stop'])
 
             truncation_length = default(shared.settings, 'truncation_length', 2048)
             truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
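After this change a request's `stop` field (string or list) is appended to the copied custom_stopping_strings, which already starts from the ['\n###'] default, rather than replacing a separate stopping_strings list. A client-side sketch (environment assumed configured as in the README):

```
import openai  # OPENAI_API_KEY / OPENAI_API_BASE assumed already set

completion = openai.Completion.create(
    model='text-davinci-002',        # nominal; the loaded model is used
    prompt='Q: What is 2 + 2?\nA:',
    stop=['\nQ:'],                   # merged into custom_stopping_strings above
    max_tokens=32,
)
print(completion['choices'][0]['text'])
```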
@@ -255,8 +240,6 @@ class Handler(BaseHTTPRequestHandler):
             max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
             # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
 
-            req_params = default_req_params.copy()
-
             req_params['max_new_tokens'] = max_tokens
             req_params['truncation_length'] = truncation_length
             req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
@@ -319,9 +302,14 @@ class Handler(BaseHTTPRequestHandler):
                         'prompt': bot_prompt,
                     }
 
+                    if instruct['user']: # WizardLM and some others have no user prompt.
+                        req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
+
                     if debug:
                         print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
                 except:
+                    req_params['custom_stopping_strings'].extend(['\nuser:'])
+
                     if debug:
                         print("Loaded default role format.")
 
@@ -397,11 +385,6 @@ class Handler(BaseHTTPRequestHandler):
                 req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
                 print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
 
-            # pass with some expected stop strings.
-            # some strange cases of "##| Instruction: " sneaking through.
-            stopping_strings += standard_stopping_strings
-            req_params['custom_stopping_strings'] = stopping_strings
-
             if req_params['stream']:
                 shared.args.chat = True
                 # begin streaming
@@ -423,19 +406,17 @@ class Handler(BaseHTTPRequestHandler):
                     chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                     chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
 
-                data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                response = chunk_size + data_chunk
+                response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                 self.wfile.write(response.encode('utf-8'))
 
             # generate reply #######################################
             if debug:
-                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
-            generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+                print({'prompt': prompt, 'req_params': req_params})
+            generator = generate_reply(prompt, req_params, is_chat=False)
 
             answer = ''
             seen_content = ''
-            longest_stop_len = max([len(x) for x in stopping_strings])
+            longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
 
             for a in generator:
                 answer = a
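The streaming path now emits plain `data: {...}` SSE frames and drops the hand-rolled chunked transfer framing (the removed chunk_size/hex lines). A consuming sketch with the official client (environment assumed configured as in the README):

```
import openai  # OPENAI_API_KEY / OPENAI_API_BASE assumed already set

# Each iteration corresponds to one 'data: {...}' frame written by the server.
for chunk in openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[{'role': 'user', 'content': 'Count to five.'}],
        stream=True):
    delta = chunk['choices'][0]['delta']
    print(delta.get('content', ''), end='', flush=True)
print()
```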
@@ -444,7 +425,7 @@ class Handler(BaseHTTPRequestHandler):
                 len_seen = len(seen_content)
                 search_start = max(len_seen - longest_stop_len, 0)
 
-                for string in stopping_strings:
+                for string in req_params['custom_stopping_strings']:
                     idx = answer.find(string, search_start)
                     if idx != -1:
                         answer = answer[:idx] # clip it.
@@ -457,7 +438,7 @@ class Handler(BaseHTTPRequestHandler):
                 # is completed, buffer and generate more, don't send it
                 buffer_and_continue = False
 
-                for string in stopping_strings:
+                for string in req_params['custom_stopping_strings']:
                     for j in range(len(string) - 1, 0, -1):
                         if answer[-j:] == string[:j]:
                             buffer_and_continue = True
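The inner loop above holds back output whenever the tail of the generated text could still turn into one of the stop strings, so a half-generated stop sequence is never streamed out. A standalone illustration of that prefix test:

```
# Standalone illustration of the buffering check used above.
def ends_with_partial_stop(answer, stop_strings):
    for string in stop_strings:
        for j in range(len(string) - 1, 0, -1):
            if answer[-j:] == string[:j]:
                return True   # e.g. text ending in '\nuse' vs. stop '\nuser:'
    return False

print(ends_with_partial_stop('Hello\nuse', ['\nuser:']))   # True  -> keep buffering
print(ends_with_partial_stop('Hello there', ['\nuser:']))  # False -> safe to send
```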
@@ -498,9 +479,7 @@ class Handler(BaseHTTPRequestHandler):
                     # So yeah... do both methods? delta and messages.
                     chunk[resp_list][0]['message'] = {'content': new_content}
                     chunk[resp_list][0]['delta'] = {'content': new_content}
-                    data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                    chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                    response = chunk_size + data_chunk
+                    response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                     self.wfile.write(response.encode('utf-8'))
                     completion_token_count += len(encode(new_content)[0])
 
@@ -527,10 +506,7 @@ class Handler(BaseHTTPRequestHandler):
                 chunk[resp_list][0]['message'] = {'content': ''}
                 chunk[resp_list][0]['delta'] = {'content': ''}
 
-                data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
-                chunk_size = hex(len(data_chunk))[2:] + '\r\n'
-                done = 'data: [DONE]\r\n\r\n'
-                response = chunk_size + data_chunk + done
+                response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
                 self.wfile.write(response.encode('utf-8'))
                 # Finished if streaming.
                 if debug:
@@ -574,7 +550,12 @@ class Handler(BaseHTTPRequestHandler):
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
+
         elif '/edits' in self.path:
+            if not shared.model:
+                self.openai_error("No model loaded.")
+                return
+
             self.send_response(200)
             self.send_access_control_headers()
             self.send_header('Content-Type', 'application/json')
@@ -586,15 +567,42 @@ class Handler(BaseHTTPRequestHandler):
             instruction = body['instruction']
             input = body.get('input', '')
 
-            instruction_template = deduce_template()
+            # Request parameters
+            req_params = default_req_params.copy()
+
+            # Alpaca is verbose so a good default prompt
+            default_template = (
+                "Below is an instruction that describes a task, paired with an input that provides further context. "
+                "Write a response that appropriately completes the request.\n\n"
+                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            )
+
+            instruction_template = default_template
+            req_params['custom_stopping_strings'] = [ '\n###' ]
+
+            # Use the special instruction/input/response template for anything trained like Alpaca
+            if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
+                try:
+                    instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
+
+                    template = instruct['turn_template']
+                    template = template\
+                        .replace('<|user|>', instruct.get('user', ''))\
+                        .replace('<|bot|>', instruct.get('bot', ''))\
+                        .replace('<|user-message|>', '{instruction}\n{input}')
+
+                    instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
+                    if instruct['user']:
+                        req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
+                except:
+                    pass
+
             edit_task = instruction_template.format(instruction=instruction, input=input)
 
             truncation_length = default(shared.settings, 'truncation_length', 2048)
             token_count = len(encode(edit_task)[0])
             max_tokens = truncation_length - token_count
 
-            req_params = default_req_params.copy()
-
             req_params['max_new_tokens'] = max_tokens
             req_params['truncation_length'] = truncation_length
             req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
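To make the inlined template handling concrete: with the default Alpaca format, the edit_task prompt assembled above looks like this (sketch; the instruction and input values are invented):

```
# Sketch of the edit_task produced by the default Alpaca template above.
default_template = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
print(default_template.format(
    instruction="Fix the spelling mistakes.",
    input="The quik brown fox",
))
# Below is an instruction that describes a task, ... completes the request.
#
# ### Instruction:
# Fix the spelling mistakes.
#
# ### Input:
# The quik brown fox
#
# ### Response:
```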
@@ -605,7 +613,7 @@ class Handler(BaseHTTPRequestHandler):
             if debug:
                 print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
 
-            generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)
+            generator = generate_reply(edit_task, req_params, is_chat=False)
 
             answer = ''
             for a in generator:
@@ -636,6 +644,7 @@ class Handler(BaseHTTPRequestHandler):
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
+
         elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
             # Stable Diffusion callout wrapper for txt2img
             # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
@@ -682,6 +691,7 @@ class Handler(BaseHTTPRequestHandler):
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
+
         elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_access_control_headers()
@@ -715,6 +725,7 @@ class Handler(BaseHTTPRequestHandler):
             if debug:
                 print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
             self.wfile.write(response.encode('utf-8'))
+
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
             self.send_response(200)
@@ -763,6 +774,7 @@ class Handler(BaseHTTPRequestHandler):
                 }]
             })
             self.wfile.write(response.encode('utf-8'))
+
         else:
             print(self.path, self.headers)
             self.send_error(404)

@@ -19,7 +19,7 @@ git+https://github.com/huggingface/transformers@e45e756d22206ca8fa9fb057c8c3d8fa
 git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467
 bitsandbytes==0.39.0; platform_system != "Windows"
 https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl; platform_system == "Windows"
-llama-cpp-python==0.1.56; platform_system != "Windows"
+llama-cpp-python==0.1.57; platform_system != "Windows"
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.56/llama_cpp_python-0.1.56-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"