Merge pull request #2551 from oobabooga/dev

This commit is contained in:
oobabooga 2023-06-06 14:40:52 -03:00 committed by GitHub
commit 3cc5ce3c42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 340 additions and 259 deletions

1
.gitignore vendored
View File

@ -17,6 +17,7 @@ torch-dumps
venv/ venv/
.venv/ .venv/
.vscode .vscode
.idea/
*.bak *.bak
*.ipynb *.ipynb
*.log *.log

View File

@ -11,6 +11,7 @@
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word; word-break: break-word;
overflow-wrap: anywhere; overflow-wrap: anywhere;
padding-top: 1px;
} }
.message { .message {

View File

@ -9,6 +9,7 @@
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word; word-break: break-word;
overflow-wrap: anywhere; overflow-wrap: anywhere;
padding-top: 1px;
} }
.message { .message {

View File

@ -9,6 +9,7 @@
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word; word-break: break-word;
overflow-wrap: anywhere; overflow-wrap: anywhere;
padding-top: 1px;
} }
.message { .message {

View File

@ -9,6 +9,7 @@
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word; word-break: break-word;
overflow-wrap: anywhere; overflow-wrap: anywhere;
padding-top: 1px;
} }
.message { .message {

View File

@ -9,6 +9,7 @@
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word; word-break: break-word;
overflow-wrap: anywhere; overflow-wrap: anywhere;
padding-top: 1px;
} }
.message { .message {
@ -74,6 +75,12 @@
.dark .chat .assistant-message { .dark .chat .assistant-message {
background-color: #374151; background-color: #374151;
border: 1px solid #4b5563;
}
.dark .chat .user-message {
background-color: #111827;
border: 1px solid #4b5563;
} }
code { code {

View File

@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
min_rows = 3 min_rows = 3
# Finding the maximum prompt size # Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
# Building the turn templates # Building the turn templates
if 'turn_template' not in state or state['turn_template'] == '': if 'turn_template' not in state or state['turn_template'] == '':

View File

@ -103,7 +103,7 @@ class ModelDownloader:
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_pt = False has_pt = False
has_ggml = False # has_ggml = False
has_safetensors = False has_safetensors = False
is_lora = False is_lora = False
while True: while True:
@ -148,7 +148,7 @@ class ModelDownloader:
has_pt = True has_pt = True
classifications.append('pt') classifications.append('pt')
elif is_ggml: elif is_ggml:
has_ggml = True # has_ggml = True
classifications.append('ggml') classifications.append('ggml')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'

View File

@ -0,0 +1,69 @@
# Adding an ingress URL through the ngrok Agent SDK for Python
[ngrok](https://ngrok.com) is a globally distributed reverse proxy commonly used for quickly getting a public URL to a
service running inside a private network, such as on your local laptop. The ngrok agent is usually
deployed inside a private network and is used to communicate with the ngrok cloud service.
By default the authtoken in the NGROK_AUTHTOKEN environment variable will be used. Alternatively one may be specified in
the `settings.json` file, see the Examples below. Retrieve your authtoken on the [Auth Token page of your ngrok dashboard](https://dashboard.ngrok.com/get-started/your-authtoken), signing up is free.
# Documentation
For a list of all available options, see [the configuration documentation](https://ngrok.com/docs/ngrok-agent/config/) or [the connect example](https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py).
The ngrok Python SDK is [on github here](https://github.com/ngrok/ngrok-py). A quickstart guide and a full API reference are included in the [ngrok-py Python API documentation](https://ngrok.github.io/ngrok-py/).
# Running
To enable ngrok install the requirements and then add `--extension ngrok` to the command line options, for instance:
```bash
pip install -r extensions/ngrok/requirements.txt
python server.py --extension ngrok
```
In the output you should then see something like this:
```bash
INFO:Loading the extension "ngrok"...
INFO:Session created
INFO:Created tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" with url "https://d83706cf7be7.ngrok.app"
INFO:Tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" TCP forwarding to "localhost:7860"
INFO:Ingress established at https://d83706cf7be7.ngrok.app
```
You can now access the webui via the url shown, in this case `https://d83706cf7be7.ngrok.app`. It is recommended to add some authentication to the ingress, see below.
# Example Settings
In `settings.json` add a `ngrok` key with a dictionary of options, for instance:
To enable basic authentication:
```json
{
"ngrok": {
"basic_auth": "user:password"
}
}
```
To enable OAUTH authentication:
```json
{
"ngrok": {
"oauth_provider": "google",
"oauth_allow_domains": "asdf.com",
"oauth_allow_emails": "asdf@asdf.com"
}
}
```
To add an authtoken instead of using the NGROK_AUTHTOKEN environment variable:
```json
{
"ngrok": {
"authtoken": "<token>",
"authtoken_from_env":false
}
}
```

View File

@ -0,0 +1 @@
ngrok==0.*

View File

@ -0,0 +1,36 @@
# Adds ngrok ingress, to use add `--extension ngrok` to the command line options
#
# Parameters can be customized in settings.json of webui, e.g.:
# {"ngrok": {"basic_auth":"user:password"} }
# or
# {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} }
#
# See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
# or the README.md in this directory.
import logging
from modules import shared
# Pick up host/port command line arguments
host = shared.args.listen_host if shared.args.listen_host and shared.args.listen else '127.0.0.1'
port = shared.args.listen_port if shared.args.listen_port else '7860'
# Default options
options = {
'addr': f"{host}:{port}",
'authtoken_from_env': True,
'session_metadata': 'text-generation-webui',
}
def ui():
settings = shared.settings.get("ngrok")
if settings:
options.update(settings)
try:
import ngrok
tunnel = ngrok.connect(**options)
logging.info(f"Ingress established at: {tunnel.url()}")
except ModuleNotFoundError:
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")

View File

@ -20,6 +20,12 @@ Example:
SD_WEBUI_URL=http://127.0.0.1:7861 SD_WEBUI_URL=http://127.0.0.1:7861
``` ```
Make sure you enable it in server launch parameters. Just make sure they include:
```
--extensions openai
```
### Embeddings (alpha) ### Embeddings (alpha)
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
@ -42,7 +48,7 @@ Almost everything you use it with will require you to set a dummy OpenAI API key
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so: With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
``` ```
OPENAI_API_KEY=dummy OPENAI_API_KEY=sk-dummy
OPENAI_API_BASE=http://127.0.0.1:5001/v1 OPENAI_API_BASE=http://127.0.0.1:5001/v1
``` ```

View File

@ -20,6 +20,7 @@ params = {
debug = True if 'OPENEDAI_DEBUG' in os.environ else False debug = True if 'OPENEDAI_DEBUG' in os.environ else False
# Slightly different defaults for OpenAI's API # Slightly different defaults for OpenAI's API
# Data type is important, Ex. use 0.0 for a float 0
default_req_params = { default_req_params = {
'max_new_tokens': 200, 'max_new_tokens': 200,
'temperature': 1.0, 'temperature': 1.0,
@ -44,14 +45,14 @@ default_req_params = {
'no_repeat_ngram_size': 0, 'no_repeat_ngram_size': 0,
'num_beams': 1, 'num_beams': 1,
'penalty_alpha': 0.0, 'penalty_alpha': 0.0,
'length_penalty': 1, 'length_penalty': 1.0,
'early_stopping': False, 'early_stopping': False,
'mirostat_mode': 0, 'mirostat_mode': 0,
'mirostat_tau': 5, 'mirostat_tau': 5.0,
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'ban_eos_token': False, 'ban_eos_token': False,
'skip_special_tokens': True, 'skip_special_tokens': True,
'custom_stopping_strings': [], 'custom_stopping_strings': ['\n###'],
} }
# Optional, install the module and download the model to enable # Optional, install the module and download the model to enable
@ -64,8 +65,6 @@ except ImportError:
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embedding_model = None embedding_model = None
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
# little helper to get defaults if arg is present but None and should be the same type as default. # little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default): def default(dic, key, default):
val = dic.get(key, default) val = dic.get(key, default)
@ -86,31 +85,6 @@ def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue)) return max(minvalue, min(value, maxvalue))
def deduce_template():
# Alpaca is verbose so a good default prompt
default_template = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
return default_template
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
except:
return default_template
def float_list_to_base64(float_list): def float_list_to_base64(float_list):
# Convert the list to a float32 array that the OpenAPI client expects # Convert the list to a float32 array that the OpenAPI client expects
float_array = np.array(float_list, dtype="float32") float_array = np.array(float_list, dtype="float32")
@ -141,6 +115,25 @@ class Handler(BaseHTTPRequestHandler):
"Authorization" "Authorization"
) )
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
self.send_response(code)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
self.end_headers()
error_resp = {
'error': {
'message': message,
'code': code,
'type': error_type,
'param': param,
}
}
if internal_message:
error_resp['internal_message'] = internal_message
response = json.dumps(error_resp)
self.wfile.write(response.encode('utf-8'))
def do_OPTIONS(self): def do_OPTIONS(self):
self.send_response(200) self.send_response(200)
self.send_access_control_headers() self.send_access_control_headers()
@ -150,42 +143,24 @@ class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path.startswith('/v1/models'): if self.path.startswith('/v1/models'):
self.send_response(200) self.send_response(200)
self.send_access_control_headers() self.send_access_control_headers()
self.send_header('Content-Type', 'application/json') self.send_header('Content-Type', 'application/json')
self.end_headers() self.end_headers()
# TODO: list all models and allow model changes via API? Lora's? # TODO: Lora's?
# This API should list capabilities, limits and pricing... # This API should list capabilities, limits and pricing...
models = [{ current_model_list = [ shared.model_name ] # The real chat/completions model
"id": shared.model_name, # The real chat/completions model embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
"object": "model", pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
"owned_by": "user", 'gpt-3.5-turbo', # /v1/chat/completions
"permission": [] 'text-curie-001', # /v1/completions, 2k context
}, { 'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
"id": st_model, # The real sentence transformer embeddings model ]
"object": "model", available_model_list = get_available_models()
"owned_by": "user", all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
"permission": []
}, { # these are expected by so much, so include some here as a dummy
"id": "gpt-3.5-turbo", # /v1/chat/completions
"object": "model",
"owned_by": "user",
"permission": []
}, {
"id": "text-curie-001", # /v1/completions, 2k context
"object": "model",
"owned_by": "user",
"permission": []
}, {
"id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
"object": "model",
"owned_by": "user",
"permission": []
}]
models.extend([{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in get_available_models() ]) models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
response = '' response = ''
if self.path == '/v1/models': if self.path == '/v1/models':
@ -203,6 +178,7 @@ class Handler(BaseHTTPRequestHandler):
}) })
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif '/billing/usage' in self.path: elif '/billing/usage' in self.path:
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
self.send_response(200) self.send_response(200)
@ -214,6 +190,7 @@ class Handler(BaseHTTPRequestHandler):
"total_usage": 0, "total_usage": 0,
}) })
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
else: else:
self.send_error(404) self.send_error(404)
@ -227,6 +204,11 @@ class Handler(BaseHTTPRequestHandler):
print(body) print(body)
if '/completions' in self.path or '/generate' in self.path: if '/completions' in self.path or '/generate' in self.path:
if not shared.model:
self.openai_error("No model loaded.")
return
is_legacy = '/generate' in self.path is_legacy = '/generate' in self.path
is_chat = 'chat' in self.path is_chat = 'chat' in self.path
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
@ -238,13 +220,16 @@ class Handler(BaseHTTPRequestHandler):
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time) cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
# Request Parameters
# Try to use openai defaults or map them to something with the same intent # Try to use openai defaults or map them to something with the same intent
stopping_strings = default(shared.settings, 'custom_stopping_strings', []) req_params = default_req_params.copy()
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
if 'stop' in body: if 'stop' in body:
if isinstance(body['stop'], str): if isinstance(body['stop'], str):
stopping_strings = [body['stop']] req_params['custom_stopping_strings'].extend([body['stop']])
elif isinstance(body['stop'], list): elif isinstance(body['stop'], list):
stopping_strings = body['stop'] req_params['custom_stopping_strings'].extend(body['stop'])
truncation_length = default(shared.settings, 'truncation_length', 2048) truncation_length = default(shared.settings, 'truncation_length', 2048)
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length) truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
@ -255,8 +240,6 @@ class Handler(BaseHTTPRequestHandler):
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens)) max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough # if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
req_params = default_req_params.copy()
req_params['max_new_tokens'] = max_tokens req_params['max_new_tokens'] = max_tokens
req_params['truncation_length'] = truncation_length req_params['truncation_length'] = truncation_length
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
@ -319,9 +302,14 @@ class Handler(BaseHTTPRequestHandler):
'prompt': bot_prompt, 'prompt': bot_prompt,
} }
if instruct['user']: # WizardLM and some others have no user prompt.
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
if debug: if debug:
print(f"Loaded instruction role format: {shared.settings['instruction_template']}") print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except: except:
req_params['custom_stopping_strings'].extend(['\nuser:'])
if debug: if debug:
print("Loaded default role format.") print("Loaded default role format.")
@ -397,11 +385,6 @@ class Handler(BaseHTTPRequestHandler):
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}") print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
# pass with some expected stop strings.
# some strange cases of "##| Instruction: " sneaking through.
stopping_strings += standard_stopping_strings
req_params['custom_stopping_strings'] = stopping_strings
if req_params['stream']: if req_params['stream']:
shared.args.chat = True shared.args.chat = True
# begin streaming # begin streaming
@ -423,19 +406,17 @@ class Handler(BaseHTTPRequestHandler):
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''} chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''} chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n' response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
response = chunk_size + data_chunk
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
# generate reply ####################################### # generate reply #######################################
if debug: if debug:
print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings}) print({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) generator = generate_reply(prompt, req_params, is_chat=False)
answer = '' answer = ''
seen_content = '' seen_content = ''
longest_stop_len = max([len(x) for x in stopping_strings]) longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
for a in generator: for a in generator:
answer = a answer = a
@ -444,7 +425,7 @@ class Handler(BaseHTTPRequestHandler):
len_seen = len(seen_content) len_seen = len(seen_content)
search_start = max(len_seen - longest_stop_len, 0) search_start = max(len_seen - longest_stop_len, 0)
for string in stopping_strings: for string in req_params['custom_stopping_strings']:
idx = answer.find(string, search_start) idx = answer.find(string, search_start)
if idx != -1: if idx != -1:
answer = answer[:idx] # clip it. answer = answer[:idx] # clip it.
@ -457,7 +438,7 @@ class Handler(BaseHTTPRequestHandler):
# is completed, buffer and generate more, don't send it # is completed, buffer and generate more, don't send it
buffer_and_continue = False buffer_and_continue = False
for string in stopping_strings: for string in req_params['custom_stopping_strings']:
for j in range(len(string) - 1, 0, -1): for j in range(len(string) - 1, 0, -1):
if answer[-j:] == string[:j]: if answer[-j:] == string[:j]:
buffer_and_continue = True buffer_and_continue = True
@ -498,9 +479,7 @@ class Handler(BaseHTTPRequestHandler):
# So yeah... do both methods? delta and messages. # So yeah... do both methods? delta and messages.
chunk[resp_list][0]['message'] = {'content': new_content} chunk[resp_list][0]['message'] = {'content': new_content}
chunk[resp_list][0]['delta'] = {'content': new_content} chunk[resp_list][0]['delta'] = {'content': new_content}
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n' response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
response = chunk_size + data_chunk
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
completion_token_count += len(encode(new_content)[0]) completion_token_count += len(encode(new_content)[0])
@ -527,10 +506,7 @@ class Handler(BaseHTTPRequestHandler):
chunk[resp_list][0]['message'] = {'content': ''} chunk[resp_list][0]['message'] = {'content': ''}
chunk[resp_list][0]['delta'] = {'content': ''} chunk[resp_list][0]['delta'] = {'content': ''}
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n' response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
done = 'data: [DONE]\r\n\r\n'
response = chunk_size + data_chunk + done
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
# Finished if streaming. # Finished if streaming.
if debug: if debug:
@ -574,7 +550,12 @@ class Handler(BaseHTTPRequestHandler):
response = json.dumps(resp) response = json.dumps(resp)
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif '/edits' in self.path: elif '/edits' in self.path:
if not shared.model:
self.openai_error("No model loaded.")
return
self.send_response(200) self.send_response(200)
self.send_access_control_headers() self.send_access_control_headers()
self.send_header('Content-Type', 'application/json') self.send_header('Content-Type', 'application/json')
@ -586,15 +567,42 @@ class Handler(BaseHTTPRequestHandler):
instruction = body['instruction'] instruction = body['instruction']
input = body.get('input', '') input = body.get('input', '')
instruction_template = deduce_template() # Request parameters
req_params = default_req_params.copy()
# Alpaca is verbose so a good default prompt
default_template = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
instruction_template = default_template
req_params['custom_stopping_strings'] = [ '\n###' ]
# Use the special instruction/input/response template for anything trained like Alpaca
if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
except:
pass
edit_task = instruction_template.format(instruction=instruction, input=input) edit_task = instruction_template.format(instruction=instruction, input=input)
truncation_length = default(shared.settings, 'truncation_length', 2048) truncation_length = default(shared.settings, 'truncation_length', 2048)
token_count = len(encode(edit_task)[0]) token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count max_tokens = truncation_length - token_count
req_params = default_req_params.copy()
req_params['max_new_tokens'] = max_tokens req_params['max_new_tokens'] = max_tokens
req_params['truncation_length'] = truncation_length req_params['truncation_length'] = truncation_length
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
@ -605,7 +613,7 @@ class Handler(BaseHTTPRequestHandler):
if debug: if debug:
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count}) print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False) generator = generate_reply(edit_task, req_params, is_chat=False)
answer = '' answer = ''
for a in generator: for a in generator:
@ -636,6 +644,7 @@ class Handler(BaseHTTPRequestHandler):
response = json.dumps(resp) response = json.dumps(resp)
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ: elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
# Stable Diffusion callout wrapper for txt2img # Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
@ -682,6 +691,7 @@ class Handler(BaseHTTPRequestHandler):
response = json.dumps(resp) response = json.dumps(resp)
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif '/embeddings' in self.path and embedding_model is not None: elif '/embeddings' in self.path and embedding_model is not None:
self.send_response(200) self.send_response(200)
self.send_access_control_headers() self.send_access_control_headers()
@ -715,6 +725,7 @@ class Handler(BaseHTTPRequestHandler):
if debug: if debug:
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}") print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif '/moderations' in self.path: elif '/moderations' in self.path:
# for now do nothing, just don't error. # for now do nothing, just don't error.
self.send_response(200) self.send_response(200)
@ -763,6 +774,7 @@ class Handler(BaseHTTPRequestHandler):
}] }]
}) })
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
else: else:
print(self.path, self.headers) print(self.path, self.headers)
self.send_error(404) self.send_error(404)

View File

@ -5,6 +5,14 @@ from peft import PeftModel
import modules.shared as shared import modules.shared as shared
from modules.logging_colors import logger from modules.logging_colors import logger
from modules.models import reload_model
try:
from auto_gptq import get_gptq_peft_model
from auto_gptq.utils.peft_utils import GPTQLoraConfig
has_auto_gptq_peft = True
except:
has_auto_gptq_peft = False
def add_lora_to_model(lora_names): def add_lora_to_model(lora_names):
@ -13,6 +21,35 @@ def add_lora_to_model(lora_names):
removed_set = prior_set - set(lora_names) removed_set = prior_set - set(lora_names)
shared.lora_names = list(lora_names) shared.lora_names = list(lora_names)
is_autogptq = 'GPTQForCausalLM' in shared.model.__class__.__name__
# AutoGPTQ case. It doesn't use the peft functions.
# Copied from https://github.com/Ph0rk0z/text-generation-webui-testing
if is_autogptq:
if not has_auto_gptq_peft:
logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
return
if len(prior_set) > 0:
reload_model()
if len(shared.lora_names) == 0:
return
else:
if len(shared.lora_names) > 1:
logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded')
peft_config = GPTQLoraConfig(
inference_mode=True,
)
lora_path = Path(f"{shared.args.lora_dir}/{shared.lora_names[0]}")
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path)
return
# Transformers case
else:
# If no LoRA needs to be added or removed, exit # If no LoRA needs to be added or removed, exit
if len(added_set) == 0 and len(removed_set) == 0: if len(added_set) == 0 and len(removed_set) == 0:
return return

View File

@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
is_instruct = state['mode'] == 'instruct' is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size # Find the maximum prompt size
chat_prompt_size = state['chat_prompt_size'] max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
all_substrings = { all_substrings = {
'chat': get_turn_substrings(state, instruct=False), 'chat': get_turn_substrings(state, instruct=False),
'instruct': get_turn_substrings(state, instruct=True) 'instruct': get_turn_substrings(state, instruct=True)
@ -441,6 +437,9 @@ def save_history(mode, timestamp=False):
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else: else:
if shared.character == 'None':
return
if timestamp: if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else: else:
@ -564,7 +563,7 @@ def load_character(character, name1, name2, instruct=False):
if not instruct: if not instruct:
shared.history['internal'] = [] shared.history['internal'] = []
shared.history['visible'] = [] shared.history['visible'] = []
if Path(f'logs/{shared.character}_persistent.json').exists(): if shared.character != 'None' and Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
else: else:
# Insert greeting if it exists # Insert greeting if it exists

View File

@ -1,5 +1,4 @@
import datetime import datetime
import traceback
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd

View File

@ -25,7 +25,6 @@ class LlamaCppModel:
@classmethod @classmethod
def from_pretrained(self, path): def from_pretrained(self, path):
result = self() result = self()
cache_capacity = 0 cache_capacity = 0
if shared.args.cache_capacity is not None: if shared.args.cache_capacity is not None:
if 'GiB' in shared.args.cache_capacity: if 'GiB' in shared.args.cache_capacity:
@ -36,7 +35,6 @@ class LlamaCppModel:
cache_capacity = int(shared.args.cache_capacity) cache_capacity = int(shared.args.cache_capacity)
logger.info("Cache capacity is " + str(cache_capacity) + " bytes") logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
params = { params = {
'model_path': str(path), 'model_path': str(path),
'n_ctx': shared.args.n_ctx, 'n_ctx': shared.args.n_ctx,
@ -47,6 +45,7 @@ class LlamaCppModel:
'use_mlock': shared.args.mlock, 'use_mlock': shared.args.mlock,
'n_gpu_layers': shared.args.n_gpu_layers 'n_gpu_layers': shared.args.n_gpu_layers
} }
self.model = Llama(**params) self.model = Llama(**params)
if cache_capacity > 0: if cache_capacity > 0:
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity)) self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
@ -57,6 +56,7 @@ class LlamaCppModel:
def encode(self, string): def encode(self, string):
if type(string) is str: if type(string) is str:
string = string.encode() string = string.encode()
return self.model.tokenize(string) return self.model.tokenize(string)
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None): def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
@ -73,12 +73,14 @@ class LlamaCppModel:
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
stream=True stream=True
) )
output = "" output = ""
for completion_chunk in completion_chunks: for completion_chunk in completion_chunks:
text = completion_chunk['choices'][0]['text'] text = completion_chunk['choices'][0]['text']
output += text output += text
if callback: if callback:
callback(text) callback(text)
return output return output
def generate_with_streaming(self, **kwargs): def generate_with_streaming(self, **kwargs):

View File

@ -1,12 +1,9 @@
import gc import gc
import json
import os import os
import re import re
import time import time
import zipfile
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
@ -338,32 +335,3 @@ def unload_model():
def reload_model(): def reload_model():
unload_model() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def load_soft_prompt(name):
if name == 'None':
shared.soft_prompt = False
shared.soft_prompt_tensor = None
else:
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
zf.extract('tensor.npy')
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
logger.info(f"\nLoading the softprompt \"{name}\".")
for field in j:
if field != 'name':
if type(j[field]) is list:
logger.info(f"{field}: {', '.join(j[field])}")
else:
logger.info(f"{field}: {j[field]}")
tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()
Path('meta.json').unlink()
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
shared.soft_prompt = True
shared.soft_prompt_tensor = tensor
return name

View File

@ -12,8 +12,6 @@ tokenizer = None
model_name = "None" model_name = "None"
model_type = None model_type = None
lora_names = [] lora_names = []
soft_prompt_tensor = None
soft_prompt = False
# Chat variables # Chat variables
history = {'internal': [], 'visible': []} history = {'internal': [], 'visible': []}
@ -61,7 +59,7 @@ settings = {
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048, 'chat_prompt_size_max': 8192,
'chat_generation_attempts': 1, 'chat_generation_attempts': 1,
'chat_generation_attempts_min': 1, 'chat_generation_attempts_min': 1,
'chat_generation_attempts_max': 10, 'chat_generation_attempts_max': 10,

View File

@ -1,7 +1,6 @@
import ast import ast
import random import random
import re import re
import threading
import time import time
import traceback import traceback
@ -28,11 +27,7 @@ def generate_reply(*args, **kwargs):
def get_max_prompt_length(state): def get_max_prompt_length(state):
max_length = state['truncation_length'] - state['max_new_tokens'] return state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
@ -81,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens) return shared.tokenizer.decode(output_ids, skip_special_tokens)
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
@ -233,13 +220,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
eos_token_ids.append(int(encode(eos_token)[0][-1])) eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Add the encoded tokens to generate_params # Add the encoded tokens to generate_params
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
original_input_ids = input_ids
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids original_input_ids = input_ids
generate_params.update({'inputs': input_ids}) generate_params.update({'inputs': input_ids})
@ -270,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
if cuda: if cuda:
output = output.cuda() output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat) yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
# Stream the reply 1 token at a time. # Stream the reply 1 token at a time.
@ -290,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
with generate_with_streaming(**generate_params) as generator: with generate_with_streaming(**generate_params) as generator:
for output in generator: for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat) yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
if output[-1] in eos_token_ids: if output[-1] in eos_token_ids:
break break

View File

@ -60,10 +60,6 @@ def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys) return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
def get_available_loras(): def get_available_loras():
return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys) return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)

View File

@ -3,7 +3,7 @@ datasets
einops einops
flexgen==0.1.7 flexgen==0.1.7
gradio_client==0.2.5 gradio_client==0.2.5
gradio==3.31.0 gradio==3.33.1
markdown markdown
numpy numpy
pandas pandas
@ -19,7 +19,7 @@ git+https://github.com/huggingface/transformers@e45e756d22206ca8fa9fb057c8c3d8fa
git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467 git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467
bitsandbytes==0.39.0; platform_system != "Windows" bitsandbytes==0.39.0; platform_system != "Windows"
https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl; platform_system == "Windows" https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl; platform_system == "Windows"
llama-cpp-python==0.1.56; platform_system != "Windows" llama-cpp-python==0.1.57; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.56/llama_cpp_python-0.1.56-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.57/llama_cpp_python-0.1.57-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"

View File

@ -26,7 +26,6 @@ import matplotlib
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
import importlib import importlib
import io
import json import json
import math import math
import os import os
@ -34,7 +33,6 @@ import re
import sys import sys
import time import time
import traceback import traceback
import zipfile
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, unload_model
from modules.text_generation import (generate_reply_wrapper, from modules.text_generation import (generate_reply_wrapper,
get_encoded_length, stop_everything_event) get_encoded_length, stop_everything_event)
@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
name = j['name']
Path('meta.json').unlink()
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
f.write(file)
return name
def open_save_prompt(): def open_save_prompt():
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True) return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
@ -392,13 +377,12 @@ def create_model_menus():
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
gr.Markdown('AutoGPTQ') gr.Markdown('GPTQ')
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton) shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.') shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
shared.gradio['gptq_for_llama'] = gr.Checkbox(label="gptq-for-llama", value=shared.args.gptq_for_llama, info='Use GPTQ-for-LLaMa loader instead of AutoGPTQ. pre_layer should be used for CPU offloading instead of gpu-memory.')
with gr.Column(): with gr.Column():
gr.Markdown('GPTQ-for-LLaMa')
shared.gradio['gptq_for_llama'] = gr.Checkbox(label="gptq-for-llama", value=shared.args.gptq_for_llama, info='Use GPTQ-for-LLaMa to load the GPTQ model instead of AutoGPTQ. pre_layer should be used for CPU offloading instead of gpu-memory.')
with gr.Row(): with gr.Row():
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None") shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None") shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
@ -457,10 +441,24 @@ def create_model_menus():
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
def create_chat_settings_menus():
if not shared.is_chat():
return
with gr.Box():
gr.Markdown("Chat parameters")
with gr.Row():
with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='chat_prompt_size', info='Set limit on prompt size by removing old messages (while retaining context and user input)', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
@ -493,13 +491,14 @@ def create_settings_menus(default_preset):
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
with gr.Column(): with gr.Column():
create_chat_settings_menus()
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.')
gr.Markdown('Beam search (uses a lot of VRAM)') gr.Markdown('Beam search')
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
@ -510,16 +509,6 @@ def create_settings_menus(default_preset):
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau') shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta') shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
gr.Markdown('Other')
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
gr.Markdown('Upload a soft prompt (.zip format):')
with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -535,8 +524,6 @@ def create_settings_menus(default_preset):
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)') gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
@ -696,17 +683,6 @@ def create_interface():
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box():
gr.Markdown("Chat parameters")
with gr.Row():
with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='chat_prompt_size', info='Set limit on prompt size by removing old messages (while retaining context and user input)', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
create_settings_menus(default_preset) create_settings_menus(default_preset)
# Create notebook mode interface # Create notebook mode interface

View File

@ -32,7 +32,7 @@ chat-instruct_command: 'Continue the chat dialogue below. Write a single reply f
<|prompt|>' <|prompt|>'
chat_prompt_size: 2048 chat_prompt_size: 2048
chat_prompt_size_min: 0 chat_prompt_size_min: 0
chat_prompt_size_max: 2048 chat_prompt_size_max: 8192
chat_generation_attempts: 1 chat_generation_attempts: 1
chat_generation_attempts_min: 1 chat_generation_attempts_min: 1
chat_generation_attempts_max: 10 chat_generation_attempts_max: 10