mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-11 21:10:40 +01:00
Merge pull request #2551 from oobabooga/dev
This commit is contained in:
commit
3cc5ce3c42
1
.gitignore
vendored
1
.gitignore
vendored
@ -17,6 +17,7 @@ torch-dumps
|
|||||||
venv/
|
venv/
|
||||||
.venv/
|
.venv/
|
||||||
.vscode
|
.vscode
|
||||||
|
.idea/
|
||||||
*.bak
|
*.bak
|
||||||
*.ipynb
|
*.ipynb
|
||||||
*.log
|
*.log
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
|
padding-top: 1px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
|
padding-top: 1px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
|
padding-top: 1px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
|
padding-top: 1px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
overflow-wrap: anywhere;
|
overflow-wrap: anywhere;
|
||||||
|
padding-top: 1px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
@ -74,6 +75,12 @@
|
|||||||
|
|
||||||
.dark .chat .assistant-message {
|
.dark .chat .assistant-message {
|
||||||
background-color: #374151;
|
background-color: #374151;
|
||||||
|
border: 1px solid #4b5563;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .chat .user-message {
|
||||||
|
background-color: #111827;
|
||||||
|
border: 1px solid #4b5563;
|
||||||
}
|
}
|
||||||
|
|
||||||
code {
|
code {
|
||||||
|
@ -160,11 +160,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
min_rows = 3
|
min_rows = 3
|
||||||
|
|
||||||
# Finding the maximum prompt size
|
# Finding the maximum prompt size
|
||||||
chat_prompt_size = state['chat_prompt_size']
|
max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
|
||||||
if shared.soft_prompt:
|
|
||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
|
||||||
|
|
||||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
|
||||||
|
|
||||||
# Building the turn templates
|
# Building the turn templates
|
||||||
if 'turn_template' not in state or state['turn_template'] == '':
|
if 'turn_template' not in state or state['turn_template'] == '':
|
||||||
|
@ -103,7 +103,7 @@ class ModelDownloader:
|
|||||||
classifications = []
|
classifications = []
|
||||||
has_pytorch = False
|
has_pytorch = False
|
||||||
has_pt = False
|
has_pt = False
|
||||||
has_ggml = False
|
# has_ggml = False
|
||||||
has_safetensors = False
|
has_safetensors = False
|
||||||
is_lora = False
|
is_lora = False
|
||||||
while True:
|
while True:
|
||||||
@ -148,7 +148,7 @@ class ModelDownloader:
|
|||||||
has_pt = True
|
has_pt = True
|
||||||
classifications.append('pt')
|
classifications.append('pt')
|
||||||
elif is_ggml:
|
elif is_ggml:
|
||||||
has_ggml = True
|
# has_ggml = True
|
||||||
classifications.append('ggml')
|
classifications.append('ggml')
|
||||||
|
|
||||||
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
|
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
|
||||||
|
69
extensions/ngrok/README.md
Normal file
69
extensions/ngrok/README.md
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# Adding an ingress URL through the ngrok Agent SDK for Python
|
||||||
|
|
||||||
|
[ngrok](https://ngrok.com) is a globally distributed reverse proxy commonly used for quickly getting a public URL to a
|
||||||
|
service running inside a private network, such as on your local laptop. The ngrok agent is usually
|
||||||
|
deployed inside a private network and is used to communicate with the ngrok cloud service.
|
||||||
|
|
||||||
|
By default the authtoken in the NGROK_AUTHTOKEN environment variable will be used. Alternatively one may be specified in
|
||||||
|
the `settings.json` file, see the Examples below. Retrieve your authtoken on the [Auth Token page of your ngrok dashboard](https://dashboard.ngrok.com/get-started/your-authtoken), signing up is free.
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
|
||||||
|
For a list of all available options, see [the configuration documentation](https://ngrok.com/docs/ngrok-agent/config/) or [the connect example](https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py).
|
||||||
|
|
||||||
|
The ngrok Python SDK is [on github here](https://github.com/ngrok/ngrok-py). A quickstart guide and a full API reference are included in the [ngrok-py Python API documentation](https://ngrok.github.io/ngrok-py/).
|
||||||
|
|
||||||
|
# Running
|
||||||
|
|
||||||
|
To enable ngrok install the requirements and then add `--extension ngrok` to the command line options, for instance:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r extensions/ngrok/requirements.txt
|
||||||
|
python server.py --extension ngrok
|
||||||
|
```
|
||||||
|
|
||||||
|
In the output you should then see something like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFO:Loading the extension "ngrok"...
|
||||||
|
INFO:Session created
|
||||||
|
INFO:Created tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" with url "https://d83706cf7be7.ngrok.app"
|
||||||
|
INFO:Tunnel "9d9d0944dc75ff9d3aae653e5eb29fe9" TCP forwarding to "localhost:7860"
|
||||||
|
INFO:Ingress established at https://d83706cf7be7.ngrok.app
|
||||||
|
```
|
||||||
|
|
||||||
|
You can now access the webui via the url shown, in this case `https://d83706cf7be7.ngrok.app`. It is recommended to add some authentication to the ingress, see below.
|
||||||
|
|
||||||
|
# Example Settings
|
||||||
|
|
||||||
|
In `settings.json` add a `ngrok` key with a dictionary of options, for instance:
|
||||||
|
|
||||||
|
To enable basic authentication:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ngrok": {
|
||||||
|
"basic_auth": "user:password"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
To enable OAUTH authentication:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ngrok": {
|
||||||
|
"oauth_provider": "google",
|
||||||
|
"oauth_allow_domains": "asdf.com",
|
||||||
|
"oauth_allow_emails": "asdf@asdf.com"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
To add an authtoken instead of using the NGROK_AUTHTOKEN environment variable:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"ngrok": {
|
||||||
|
"authtoken": "<token>",
|
||||||
|
"authtoken_from_env":false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
1
extensions/ngrok/requirements.txt
Normal file
1
extensions/ngrok/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
ngrok==0.*
|
36
extensions/ngrok/script.py
Normal file
36
extensions/ngrok/script.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Adds ngrok ingress, to use add `--extension ngrok` to the command line options
|
||||||
|
#
|
||||||
|
# Parameters can be customized in settings.json of webui, e.g.:
|
||||||
|
# {"ngrok": {"basic_auth":"user:password"} }
|
||||||
|
# or
|
||||||
|
# {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} }
|
||||||
|
#
|
||||||
|
# See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
|
||||||
|
# or the README.md in this directory.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
# Pick up host/port command line arguments
|
||||||
|
host = shared.args.listen_host if shared.args.listen_host and shared.args.listen else '127.0.0.1'
|
||||||
|
port = shared.args.listen_port if shared.args.listen_port else '7860'
|
||||||
|
|
||||||
|
# Default options
|
||||||
|
options = {
|
||||||
|
'addr': f"{host}:{port}",
|
||||||
|
'authtoken_from_env': True,
|
||||||
|
'session_metadata': 'text-generation-webui',
|
||||||
|
}
|
||||||
|
|
||||||
|
def ui():
|
||||||
|
settings = shared.settings.get("ngrok")
|
||||||
|
if settings:
|
||||||
|
options.update(settings)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ngrok
|
||||||
|
tunnel = ngrok.connect(**options)
|
||||||
|
logging.info(f"Ingress established at: {tunnel.url()}")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
||||||
|
|
@ -20,6 +20,12 @@ Example:
|
|||||||
SD_WEBUI_URL=http://127.0.0.1:7861
|
SD_WEBUI_URL=http://127.0.0.1:7861
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Make sure you enable it in server launch parameters. Just make sure they include:
|
||||||
|
|
||||||
|
```
|
||||||
|
--extensions openai
|
||||||
|
```
|
||||||
|
|
||||||
### Embeddings (alpha)
|
### Embeddings (alpha)
|
||||||
|
|
||||||
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
Embeddings requires ```sentence-transformers``` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: ```sentence-transformers/all-mpnet-base-v2``` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default ```text-embedding-ada-002``` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future.
|
||||||
@ -42,7 +48,7 @@ Almost everything you use it with will require you to set a dummy OpenAI API key
|
|||||||
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
|
With the [official python openai client](https://github.com/openai/openai-python), you can set the OPENAI_API_BASE environment variable before you import the openai module, like so:
|
||||||
|
|
||||||
```
|
```
|
||||||
OPENAI_API_KEY=dummy
|
OPENAI_API_KEY=sk-dummy
|
||||||
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
OPENAI_API_BASE=http://127.0.0.1:5001/v1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ params = {
|
|||||||
debug = True if 'OPENEDAI_DEBUG' in os.environ else False
|
debug = True if 'OPENEDAI_DEBUG' in os.environ else False
|
||||||
|
|
||||||
# Slightly different defaults for OpenAI's API
|
# Slightly different defaults for OpenAI's API
|
||||||
|
# Data type is important, Ex. use 0.0 for a float 0
|
||||||
default_req_params = {
|
default_req_params = {
|
||||||
'max_new_tokens': 200,
|
'max_new_tokens': 200,
|
||||||
'temperature': 1.0,
|
'temperature': 1.0,
|
||||||
@ -44,14 +45,14 @@ default_req_params = {
|
|||||||
'no_repeat_ngram_size': 0,
|
'no_repeat_ngram_size': 0,
|
||||||
'num_beams': 1,
|
'num_beams': 1,
|
||||||
'penalty_alpha': 0.0,
|
'penalty_alpha': 0.0,
|
||||||
'length_penalty': 1,
|
'length_penalty': 1.0,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'mirostat_mode': 0,
|
'mirostat_mode': 0,
|
||||||
'mirostat_tau': 5,
|
'mirostat_tau': 5.0,
|
||||||
'mirostat_eta': 0.1,
|
'mirostat_eta': 0.1,
|
||||||
'ban_eos_token': False,
|
'ban_eos_token': False,
|
||||||
'skip_special_tokens': True,
|
'skip_special_tokens': True,
|
||||||
'custom_stopping_strings': [],
|
'custom_stopping_strings': ['\n###'],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Optional, install the module and download the model to enable
|
# Optional, install the module and download the model to enable
|
||||||
@ -64,8 +65,6 @@ except ImportError:
|
|||||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
embedding_model = None
|
embedding_model = None
|
||||||
|
|
||||||
standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
|
|
||||||
|
|
||||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||||
def default(dic, key, default):
|
def default(dic, key, default):
|
||||||
val = dic.get(key, default)
|
val = dic.get(key, default)
|
||||||
@ -86,31 +85,6 @@ def clamp(value, minvalue, maxvalue):
|
|||||||
return max(minvalue, min(value, maxvalue))
|
return max(minvalue, min(value, maxvalue))
|
||||||
|
|
||||||
|
|
||||||
def deduce_template():
|
|
||||||
# Alpaca is verbose so a good default prompt
|
|
||||||
default_template = (
|
|
||||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
|
||||||
if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
|
|
||||||
return default_template
|
|
||||||
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
template = template\
|
|
||||||
.replace('<|user|>', instruct.get('user', ''))\
|
|
||||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
|
||||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
|
||||||
return instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
|
||||||
except:
|
|
||||||
return default_template
|
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_list):
|
def float_list_to_base64(float_list):
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
float_array = np.array(float_list, dtype="float32")
|
float_array = np.array(float_list, dtype="float32")
|
||||||
@ -139,8 +113,27 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"Origin, Accept, X-Requested-With, Content-Type, "
|
"Origin, Accept, X-Requested-With, Content-Type, "
|
||||||
"Access-Control-Request-Method, Access-Control-Request-Headers, "
|
"Access-Control-Request-Method, Access-Control-Request-Headers, "
|
||||||
"Authorization"
|
"Authorization"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
||||||
|
self.send_response(code)
|
||||||
|
self.send_access_control_headers()
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
error_resp = {
|
||||||
|
'error': {
|
||||||
|
'message': message,
|
||||||
|
'code': code,
|
||||||
|
'type': error_type,
|
||||||
|
'param': param,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if internal_message:
|
||||||
|
error_resp['internal_message'] = internal_message
|
||||||
|
|
||||||
|
response = json.dumps(error_resp)
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def do_OPTIONS(self):
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
@ -150,42 +143,24 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path.startswith('/v1/models'):
|
if self.path.startswith('/v1/models'):
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
# TODO: list all models and allow model changes via API? Lora's?
|
# TODO: Lora's?
|
||||||
# This API should list capabilities, limits and pricing...
|
# This API should list capabilities, limits and pricing...
|
||||||
models = [{
|
current_model_list = [ shared.model_name ] # The real chat/completions model
|
||||||
"id": shared.model_name, # The real chat/completions model
|
embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
|
||||||
"object": "model",
|
pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
|
||||||
"owned_by": "user",
|
'gpt-3.5-turbo', # /v1/chat/completions
|
||||||
"permission": []
|
'text-curie-001', # /v1/completions, 2k context
|
||||||
}, {
|
'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
||||||
"id": st_model, # The real sentence transformer embeddings model
|
]
|
||||||
"object": "model",
|
available_model_list = get_available_models()
|
||||||
"owned_by": "user",
|
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
|
||||||
"permission": []
|
|
||||||
}, { # these are expected by so much, so include some here as a dummy
|
|
||||||
"id": "gpt-3.5-turbo", # /v1/chat/completions
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "user",
|
|
||||||
"permission": []
|
|
||||||
}, {
|
|
||||||
"id": "text-curie-001", # /v1/completions, 2k context
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "user",
|
|
||||||
"permission": []
|
|
||||||
}, {
|
|
||||||
"id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "user",
|
|
||||||
"permission": []
|
|
||||||
}]
|
|
||||||
|
|
||||||
models.extend([{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in get_available_models() ])
|
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||||
|
|
||||||
response = ''
|
response = ''
|
||||||
if self.path == '/v1/models':
|
if self.path == '/v1/models':
|
||||||
@ -203,6 +178,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
})
|
})
|
||||||
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif '/billing/usage' in self.path:
|
elif '/billing/usage' in self.path:
|
||||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
@ -214,6 +190,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"total_usage": 0,
|
"total_usage": 0,
|
||||||
})
|
})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
@ -227,6 +204,11 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
print(body)
|
print(body)
|
||||||
|
|
||||||
if '/completions' in self.path or '/generate' in self.path:
|
if '/completions' in self.path or '/generate' in self.path:
|
||||||
|
|
||||||
|
if not shared.model:
|
||||||
|
self.openai_error("No model loaded.")
|
||||||
|
return
|
||||||
|
|
||||||
is_legacy = '/generate' in self.path
|
is_legacy = '/generate' in self.path
|
||||||
is_chat = 'chat' in self.path
|
is_chat = 'chat' in self.path
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
@ -238,13 +220,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
|
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
|
||||||
|
|
||||||
|
# Request Parameters
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
# Try to use openai defaults or map them to something with the same intent
|
||||||
stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
|
req_params = default_req_params.copy()
|
||||||
|
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
|
||||||
|
|
||||||
if 'stop' in body:
|
if 'stop' in body:
|
||||||
if isinstance(body['stop'], str):
|
if isinstance(body['stop'], str):
|
||||||
stopping_strings = [body['stop']]
|
req_params['custom_stopping_strings'].extend([body['stop']])
|
||||||
elif isinstance(body['stop'], list):
|
elif isinstance(body['stop'], list):
|
||||||
stopping_strings = body['stop']
|
req_params['custom_stopping_strings'].extend(body['stop'])
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
||||||
@ -255,8 +240,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
||||||
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
||||||
|
|
||||||
req_params = default_req_params.copy()
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
req_params['max_new_tokens'] = max_tokens
|
||||||
req_params['truncation_length'] = truncation_length
|
req_params['truncation_length'] = truncation_length
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||||
@ -319,9 +302,14 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'prompt': bot_prompt,
|
'prompt': bot_prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if instruct['user']: # WizardLM and some others have no user prompt.
|
||||||
|
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||||
except:
|
except:
|
||||||
|
req_params['custom_stopping_strings'].extend(['\nuser:'])
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print("Loaded default role format.")
|
print("Loaded default role format.")
|
||||||
|
|
||||||
@ -396,11 +384,6 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
|
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
|
||||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||||
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
||||||
|
|
||||||
# pass with some expected stop strings.
|
|
||||||
# some strange cases of "##| Instruction: " sneaking through.
|
|
||||||
stopping_strings += standard_stopping_strings
|
|
||||||
req_params['custom_stopping_strings'] = stopping_strings
|
|
||||||
|
|
||||||
if req_params['stream']:
|
if req_params['stream']:
|
||||||
shared.args.chat = True
|
shared.args.chat = True
|
||||||
@ -423,19 +406,17 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
||||||
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
||||||
|
|
||||||
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||||
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
|
|
||||||
response = chunk_size + data_chunk
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
if debug:
|
if debug:
|
||||||
print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
|
print({'prompt': prompt, 'req_params': req_params})
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
generator = generate_reply(prompt, req_params, is_chat=False)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
seen_content = ''
|
seen_content = ''
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings])
|
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
|
||||||
|
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a
|
||||||
@ -444,7 +425,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
len_seen = len(seen_content)
|
len_seen = len(seen_content)
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
search_start = max(len_seen - longest_stop_len, 0)
|
||||||
|
|
||||||
for string in stopping_strings:
|
for string in req_params['custom_stopping_strings']:
|
||||||
idx = answer.find(string, search_start)
|
idx = answer.find(string, search_start)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
answer = answer[:idx] # clip it.
|
answer = answer[:idx] # clip it.
|
||||||
@ -457,7 +438,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# is completed, buffer and generate more, don't send it
|
# is completed, buffer and generate more, don't send it
|
||||||
buffer_and_continue = False
|
buffer_and_continue = False
|
||||||
|
|
||||||
for string in stopping_strings:
|
for string in req_params['custom_stopping_strings']:
|
||||||
for j in range(len(string) - 1, 0, -1):
|
for j in range(len(string) - 1, 0, -1):
|
||||||
if answer[-j:] == string[:j]:
|
if answer[-j:] == string[:j]:
|
||||||
buffer_and_continue = True
|
buffer_and_continue = True
|
||||||
@ -498,9 +479,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
# So yeah... do both methods? delta and messages.
|
# So yeah... do both methods? delta and messages.
|
||||||
chunk[resp_list][0]['message'] = {'content': new_content}
|
chunk[resp_list][0]['message'] = {'content': new_content}
|
||||||
chunk[resp_list][0]['delta'] = {'content': new_content}
|
chunk[resp_list][0]['delta'] = {'content': new_content}
|
||||||
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||||
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
|
|
||||||
response = chunk_size + data_chunk
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
completion_token_count += len(encode(new_content)[0])
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
|
||||||
@ -527,10 +506,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
chunk[resp_list][0]['message'] = {'content': ''}
|
chunk[resp_list][0]['message'] = {'content': ''}
|
||||||
chunk[resp_list][0]['delta'] = {'content': ''}
|
chunk[resp_list][0]['delta'] = {'content': ''}
|
||||||
|
|
||||||
data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
|
||||||
chunk_size = hex(len(data_chunk))[2:] + '\r\n'
|
|
||||||
done = 'data: [DONE]\r\n\r\n'
|
|
||||||
response = chunk_size + data_chunk + done
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
# Finished if streaming.
|
# Finished if streaming.
|
||||||
if debug:
|
if debug:
|
||||||
@ -574,7 +550,12 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
response = json.dumps(resp)
|
response = json.dumps(resp)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif '/edits' in self.path:
|
elif '/edits' in self.path:
|
||||||
|
if not shared.model:
|
||||||
|
self.openai_error("No model loaded.")
|
||||||
|
return
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
@ -586,15 +567,42 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
instruction = body['instruction']
|
instruction = body['instruction']
|
||||||
input = body.get('input', '')
|
input = body.get('input', '')
|
||||||
|
|
||||||
instruction_template = deduce_template()
|
# Request parameters
|
||||||
|
req_params = default_req_params.copy()
|
||||||
|
|
||||||
|
# Alpaca is verbose so a good default prompt
|
||||||
|
default_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
instruction_template = default_template
|
||||||
|
req_params['custom_stopping_strings'] = [ '\n###' ]
|
||||||
|
|
||||||
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||||
|
if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
template = template\
|
||||||
|
.replace('<|user|>', instruct.get('user', ''))\
|
||||||
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||||
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||||
|
|
||||||
|
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||||
|
if instruct['user']:
|
||||||
|
req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
||||||
token_count = len(encode(edit_task)[0])
|
token_count = len(encode(edit_task)[0])
|
||||||
max_tokens = truncation_length - token_count
|
max_tokens = truncation_length - token_count
|
||||||
|
|
||||||
req_params = default_req_params.copy()
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
req_params['max_new_tokens'] = max_tokens
|
||||||
req_params['truncation_length'] = truncation_length
|
req_params['truncation_length'] = truncation_length
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||||
@ -605,7 +613,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if debug:
|
if debug:
|
||||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)
|
generator = generate_reply(edit_task, req_params, is_chat=False)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
for a in generator:
|
for a in generator:
|
||||||
@ -636,6 +644,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
response = json.dumps(resp)
|
response = json.dumps(resp)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||||
# Stable Diffusion callout wrapper for txt2img
|
# Stable Diffusion callout wrapper for txt2img
|
||||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||||
@ -682,6 +691,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
response = json.dumps(resp)
|
response = json.dumps(resp)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif '/embeddings' in self.path and embedding_model is not None:
|
elif '/embeddings' in self.path and embedding_model is not None:
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
@ -715,6 +725,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
if debug:
|
if debug:
|
||||||
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
elif '/moderations' in self.path:
|
elif '/moderations' in self.path:
|
||||||
# for now do nothing, just don't error.
|
# for now do nothing, just don't error.
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
@ -763,6 +774,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(self.path, self.headers)
|
print(self.path, self.headers)
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
103
modules/LoRA.py
103
modules/LoRA.py
@ -5,6 +5,14 @@ from peft import PeftModel
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.models import reload_model
|
||||||
|
|
||||||
|
try:
|
||||||
|
from auto_gptq import get_gptq_peft_model
|
||||||
|
from auto_gptq.utils.peft_utils import GPTQLoraConfig
|
||||||
|
has_auto_gptq_peft = True
|
||||||
|
except:
|
||||||
|
has_auto_gptq_peft = False
|
||||||
|
|
||||||
|
|
||||||
def add_lora_to_model(lora_names):
|
def add_lora_to_model(lora_names):
|
||||||
@ -13,43 +21,72 @@ def add_lora_to_model(lora_names):
|
|||||||
removed_set = prior_set - set(lora_names)
|
removed_set = prior_set - set(lora_names)
|
||||||
shared.lora_names = list(lora_names)
|
shared.lora_names = list(lora_names)
|
||||||
|
|
||||||
# If no LoRA needs to be added or removed, exit
|
is_autogptq = 'GPTQForCausalLM' in shared.model.__class__.__name__
|
||||||
if len(added_set) == 0 and len(removed_set) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add a LoRA when another LoRA is already present
|
# AutoGPTQ case. It doesn't use the peft functions.
|
||||||
if len(removed_set) == 0 and len(prior_set) > 0:
|
# Copied from https://github.com/Ph0rk0z/text-generation-webui-testing
|
||||||
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
if is_autogptq:
|
||||||
for lora in added_set:
|
if not has_auto_gptq_peft:
|
||||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
|
||||||
|
return
|
||||||
|
|
||||||
return
|
if len(prior_set) > 0:
|
||||||
|
reload_model()
|
||||||
|
|
||||||
# If any LoRA needs to be removed, start over
|
if len(shared.lora_names) == 0:
|
||||||
if len(removed_set) > 0:
|
return
|
||||||
shared.model.disable_adapter()
|
else:
|
||||||
shared.model = shared.model.base_model.model
|
if len(shared.lora_names) > 1:
|
||||||
|
logger.warning('AutoGPTQ can only work with 1 LoRA at the moment. Only the first one in the list will be loaded')
|
||||||
|
|
||||||
if len(lora_names) > 0:
|
peft_config = GPTQLoraConfig(
|
||||||
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
inference_mode=True,
|
||||||
params = {}
|
)
|
||||||
if not shared.args.cpu:
|
|
||||||
params['dtype'] = shared.model.dtype
|
|
||||||
if hasattr(shared.model, "hf_device_map"):
|
|
||||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
|
||||||
elif shared.args.load_in_8bit:
|
|
||||||
params['device_map'] = {'': 0}
|
|
||||||
|
|
||||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
|
lora_path = Path(f"{shared.args.lora_dir}/{shared.lora_names[0]}")
|
||||||
|
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
|
||||||
|
shared.model = get_gptq_peft_model(shared.model, peft_config, lora_path)
|
||||||
|
return
|
||||||
|
|
||||||
for lora in lora_names[1:]:
|
# Transformers case
|
||||||
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
else:
|
||||||
|
# If no LoRA needs to be added or removed, exit
|
||||||
|
if len(added_set) == 0 and len(removed_set) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
# Add a LoRA when another LoRA is already present
|
||||||
shared.model.half()
|
if len(removed_set) == 0 and len(prior_set) > 0:
|
||||||
if not hasattr(shared.model, "hf_device_map"):
|
logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
|
||||||
if torch.has_mps:
|
for lora in added_set:
|
||||||
device = torch.device('mps')
|
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||||
shared.model = shared.model.to(device)
|
|
||||||
else:
|
return
|
||||||
shared.model = shared.model.cuda()
|
|
||||||
|
# If any LoRA needs to be removed, start over
|
||||||
|
if len(removed_set) > 0:
|
||||||
|
shared.model.disable_adapter()
|
||||||
|
shared.model = shared.model.base_model.model
|
||||||
|
|
||||||
|
if len(lora_names) > 0:
|
||||||
|
logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
|
||||||
|
params = {}
|
||||||
|
if not shared.args.cpu:
|
||||||
|
params['dtype'] = shared.model.dtype
|
||||||
|
if hasattr(shared.model, "hf_device_map"):
|
||||||
|
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||||
|
elif shared.args.load_in_8bit:
|
||||||
|
params['device_map'] = {'': 0}
|
||||||
|
|
||||||
|
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
|
||||||
|
|
||||||
|
for lora in lora_names[1:]:
|
||||||
|
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
|
||||||
|
|
||||||
|
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||||
|
shared.model.half()
|
||||||
|
if not hasattr(shared.model, "hf_device_map"):
|
||||||
|
if torch.has_mps:
|
||||||
|
device = torch.device('mps')
|
||||||
|
shared.model = shared.model.to(device)
|
||||||
|
else:
|
||||||
|
shared.model = shared.model.cuda()
|
||||||
|
@ -55,11 +55,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
is_instruct = state['mode'] == 'instruct'
|
is_instruct = state['mode'] == 'instruct'
|
||||||
|
|
||||||
# Find the maximum prompt size
|
# Find the maximum prompt size
|
||||||
chat_prompt_size = state['chat_prompt_size']
|
max_length = min(get_max_prompt_length(state), state['chat_prompt_size'])
|
||||||
if shared.soft_prompt:
|
|
||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
|
||||||
|
|
||||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
|
||||||
all_substrings = {
|
all_substrings = {
|
||||||
'chat': get_turn_substrings(state, instruct=False),
|
'chat': get_turn_substrings(state, instruct=False),
|
||||||
'instruct': get_turn_substrings(state, instruct=True)
|
'instruct': get_turn_substrings(state, instruct=True)
|
||||||
@ -441,6 +437,9 @@ def save_history(mode, timestamp=False):
|
|||||||
|
|
||||||
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
else:
|
else:
|
||||||
|
if shared.character == 'None':
|
||||||
|
return
|
||||||
|
|
||||||
if timestamp:
|
if timestamp:
|
||||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
else:
|
else:
|
||||||
@ -564,7 +563,7 @@ def load_character(character, name1, name2, instruct=False):
|
|||||||
if not instruct:
|
if not instruct:
|
||||||
shared.history['internal'] = []
|
shared.history['internal'] = []
|
||||||
shared.history['visible'] = []
|
shared.history['visible'] = []
|
||||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
if shared.character != 'None' and Path(f'logs/{shared.character}_persistent.json').exists():
|
||||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||||
else:
|
else:
|
||||||
# Insert greeting if it exists
|
# Insert greeting if it exists
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -25,7 +25,6 @@ class LlamaCppModel:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(self, path):
|
def from_pretrained(self, path):
|
||||||
result = self()
|
result = self()
|
||||||
|
|
||||||
cache_capacity = 0
|
cache_capacity = 0
|
||||||
if shared.args.cache_capacity is not None:
|
if shared.args.cache_capacity is not None:
|
||||||
if 'GiB' in shared.args.cache_capacity:
|
if 'GiB' in shared.args.cache_capacity:
|
||||||
@ -36,7 +35,6 @@ class LlamaCppModel:
|
|||||||
cache_capacity = int(shared.args.cache_capacity)
|
cache_capacity = int(shared.args.cache_capacity)
|
||||||
|
|
||||||
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'model_path': str(path),
|
'model_path': str(path),
|
||||||
'n_ctx': shared.args.n_ctx,
|
'n_ctx': shared.args.n_ctx,
|
||||||
@ -47,6 +45,7 @@ class LlamaCppModel:
|
|||||||
'use_mlock': shared.args.mlock,
|
'use_mlock': shared.args.mlock,
|
||||||
'n_gpu_layers': shared.args.n_gpu_layers
|
'n_gpu_layers': shared.args.n_gpu_layers
|
||||||
}
|
}
|
||||||
|
|
||||||
self.model = Llama(**params)
|
self.model = Llama(**params)
|
||||||
if cache_capacity > 0:
|
if cache_capacity > 0:
|
||||||
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
|
self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
|
||||||
@ -57,6 +56,7 @@ class LlamaCppModel:
|
|||||||
def encode(self, string):
|
def encode(self, string):
|
||||||
if type(string) is str:
|
if type(string) is str:
|
||||||
string = string.encode()
|
string = string.encode()
|
||||||
|
|
||||||
return self.model.tokenize(string)
|
return self.model.tokenize(string)
|
||||||
|
|
||||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
||||||
@ -73,12 +73,14 @@ class LlamaCppModel:
|
|||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
output = ""
|
output = ""
|
||||||
for completion_chunk in completion_chunks:
|
for completion_chunk in completion_chunks:
|
||||||
text = completion_chunk['choices'][0]['text']
|
text = completion_chunk['choices'][0]['text']
|
||||||
output += text
|
output += text
|
||||||
if callback:
|
if callback:
|
||||||
callback(text)
|
callback(text)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def generate_with_streaming(self, **kwargs):
|
def generate_with_streaming(self, **kwargs):
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import gc
|
import gc
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import zipfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
@ -338,32 +335,3 @@ def unload_model():
|
|||||||
def reload_model():
|
def reload_model():
|
||||||
unload_model()
|
unload_model()
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
def load_soft_prompt(name):
|
|
||||||
if name == 'None':
|
|
||||||
shared.soft_prompt = False
|
|
||||||
shared.soft_prompt_tensor = None
|
|
||||||
else:
|
|
||||||
with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
|
|
||||||
zf.extract('tensor.npy')
|
|
||||||
zf.extract('meta.json')
|
|
||||||
j = json.loads(open('meta.json', 'r').read())
|
|
||||||
logger.info(f"\nLoading the softprompt \"{name}\".")
|
|
||||||
for field in j:
|
|
||||||
if field != 'name':
|
|
||||||
if type(j[field]) is list:
|
|
||||||
logger.info(f"{field}: {', '.join(j[field])}")
|
|
||||||
else:
|
|
||||||
logger.info(f"{field}: {j[field]}")
|
|
||||||
|
|
||||||
tensor = np.load('tensor.npy')
|
|
||||||
Path('tensor.npy').unlink()
|
|
||||||
Path('meta.json').unlink()
|
|
||||||
|
|
||||||
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
|
|
||||||
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
|
||||||
shared.soft_prompt = True
|
|
||||||
shared.soft_prompt_tensor = tensor
|
|
||||||
|
|
||||||
return name
|
|
||||||
|
@ -12,8 +12,6 @@ tokenizer = None
|
|||||||
model_name = "None"
|
model_name = "None"
|
||||||
model_type = None
|
model_type = None
|
||||||
lora_names = []
|
lora_names = []
|
||||||
soft_prompt_tensor = None
|
|
||||||
soft_prompt = False
|
|
||||||
|
|
||||||
# Chat variables
|
# Chat variables
|
||||||
history = {'internal': [], 'visible': []}
|
history = {'internal': [], 'visible': []}
|
||||||
@ -61,7 +59,7 @@ settings = {
|
|||||||
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
|
||||||
'chat_prompt_size': 2048,
|
'chat_prompt_size': 2048,
|
||||||
'chat_prompt_size_min': 0,
|
'chat_prompt_size_min': 0,
|
||||||
'chat_prompt_size_max': 2048,
|
'chat_prompt_size_max': 8192,
|
||||||
'chat_generation_attempts': 1,
|
'chat_generation_attempts': 1,
|
||||||
'chat_generation_attempts_min': 1,
|
'chat_generation_attempts_min': 1,
|
||||||
'chat_generation_attempts_max': 10,
|
'chat_generation_attempts_max': 10,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -28,11 +27,7 @@ def generate_reply(*args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(state):
|
def get_max_prompt_length(state):
|
||||||
max_length = state['truncation_length'] - state['max_new_tokens']
|
return state['truncation_length'] - state['max_new_tokens']
|
||||||
if shared.soft_prompt:
|
|
||||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
|
||||||
|
|
||||||
return max_length
|
|
||||||
|
|
||||||
|
|
||||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
@ -81,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
|
|||||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||||
|
|
||||||
|
|
||||||
def generate_softprompt_input_tensors(input_ids):
|
|
||||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
|
||||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
|
||||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
|
||||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
|
||||||
return inputs_embeds, filler_input_ids
|
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
@ -233,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||||
|
|
||||||
# Add the encoded tokens to generate_params
|
# Add the encoded tokens to generate_params
|
||||||
if shared.soft_prompt:
|
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
original_input_ids = input_ids
|
||||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
generate_params.update({'inputs': input_ids})
|
||||||
original_input_ids = input_ids
|
if inputs_embeds is not None:
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
generate_params.update({'inputs': filler_input_ids})
|
|
||||||
else:
|
|
||||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
|
||||||
original_input_ids = input_ids
|
|
||||||
generate_params.update({'inputs': input_ids})
|
|
||||||
if inputs_embeds is not None:
|
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
|
||||||
|
|
||||||
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
|
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
@ -270,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
if cuda:
|
if cuda:
|
||||||
output = output.cuda()
|
output = output.cuda()
|
||||||
|
|
||||||
if shared.soft_prompt:
|
|
||||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
|
||||||
|
|
||||||
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
||||||
|
|
||||||
# Stream the reply 1 token at a time.
|
# Stream the reply 1 token at a time.
|
||||||
@ -290,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||||||
|
|
||||||
with generate_with_streaming(**generate_params) as generator:
|
with generate_with_streaming(**generate_params) as generator:
|
||||||
for output in generator:
|
for output in generator:
|
||||||
if shared.soft_prompt:
|
|
||||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
|
||||||
|
|
||||||
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
|
||||||
if output[-1] in eos_token_ids:
|
if output[-1] in eos_token_ids:
|
||||||
break
|
break
|
||||||
|
@ -60,10 +60,6 @@ def get_available_extensions():
|
|||||||
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
|
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
def get_available_softprompts():
|
|
||||||
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_loras():
|
def get_available_loras():
|
||||||
return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
|
return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ datasets
|
|||||||
einops
|
einops
|
||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio_client==0.2.5
|
gradio_client==0.2.5
|
||||||
gradio==3.31.0
|
gradio==3.33.1
|
||||||
markdown
|
markdown
|
||||||
numpy
|
numpy
|
||||||
pandas
|
pandas
|
||||||
@ -19,7 +19,7 @@ git+https://github.com/huggingface/transformers@e45e756d22206ca8fa9fb057c8c3d8fa
|
|||||||
git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467
|
git+https://github.com/huggingface/accelerate@0226f750257b3bf2cadc4f189f9eef0c764a0467
|
||||||
bitsandbytes==0.39.0; platform_system != "Windows"
|
bitsandbytes==0.39.0; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl; platform_system == "Windows"
|
||||||
llama-cpp-python==0.1.56; platform_system != "Windows"
|
llama-cpp-python==0.1.57; platform_system != "Windows"
|
||||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.56/llama_cpp_python-0.1.56-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.57/llama_cpp_python-0.1.57-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
|
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.0/auto_gptq-0.2.0+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
|
||||||
|
66
server.py
66
server.py
@ -26,7 +26,6 @@ import matplotlib
|
|||||||
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -34,7 +33,6 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import zipfile
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -50,7 +48,7 @@ from modules import chat, shared, training, ui, utils
|
|||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt, unload_model
|
from modules.models import load_model, unload_model
|
||||||
from modules.text_generation import (generate_reply_wrapper,
|
from modules.text_generation import (generate_reply_wrapper,
|
||||||
get_encoded_length, stop_everything_event)
|
get_encoded_length, stop_everything_event)
|
||||||
|
|
||||||
@ -119,19 +117,6 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
|||||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||||
|
|
||||||
|
|
||||||
def upload_soft_prompt(file):
|
|
||||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
|
||||||
zf.extract('meta.json')
|
|
||||||
j = json.loads(open('meta.json', 'r').read())
|
|
||||||
name = j['name']
|
|
||||||
Path('meta.json').unlink()
|
|
||||||
|
|
||||||
with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
|
|
||||||
f.write(file)
|
|
||||||
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def open_save_prompt():
|
def open_save_prompt():
|
||||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
||||||
return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
|
return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
|
||||||
@ -392,13 +377,12 @@ def create_model_menus():
|
|||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown('AutoGPTQ')
|
gr.Markdown('GPTQ')
|
||||||
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
|
shared.gradio['triton'] = gr.Checkbox(label="triton", value=shared.args.triton)
|
||||||
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
|
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
|
||||||
|
shared.gradio['gptq_for_llama'] = gr.Checkbox(label="gptq-for-llama", value=shared.args.gptq_for_llama, info='Use GPTQ-for-LLaMa loader instead of AutoGPTQ. pre_layer should be used for CPU offloading instead of gpu-memory.')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown('GPTQ-for-LLaMa')
|
|
||||||
shared.gradio['gptq_for_llama'] = gr.Checkbox(label="gptq-for-llama", value=shared.args.gptq_for_llama, info='Use GPTQ-for-LLaMa to load the GPTQ model instead of AutoGPTQ. pre_layer should be used for CPU offloading instead of gpu-memory.')
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
|
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
|
||||||
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
|
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
|
||||||
@ -457,10 +441,24 @@ def create_model_menus():
|
|||||||
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
||||||
|
|
||||||
|
|
||||||
|
def create_chat_settings_menus():
|
||||||
|
if not shared.is_chat():
|
||||||
|
return
|
||||||
|
|
||||||
|
with gr.Box():
|
||||||
|
gr.Markdown("Chat parameters")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
|
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='chat_prompt_size', info='Set limit on prompt size by removing old messages (while retaining context and user input)', value=shared.settings['chat_prompt_size'])
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
|
||||||
|
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
||||||
|
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
|
|
||||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -493,13 +491,14 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
create_chat_settings_menus()
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown('Contrastive search')
|
gr.Markdown('Contrastive search')
|
||||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.')
|
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4.')
|
||||||
|
|
||||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
gr.Markdown('Beam search')
|
||||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
|
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
|
||||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
@ -510,16 +509,6 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
|
||||||
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
|
||||||
|
|
||||||
gr.Markdown('Other')
|
|
||||||
with gr.Accordion('Soft prompt', open=False):
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
|
|
||||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
|
|
||||||
|
|
||||||
gr.Markdown('Upload a soft prompt (.zip format):')
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
|
||||||
|
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -535,8 +524,6 @@ def create_settings_menus(default_preset):
|
|||||||
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
|
gr.Markdown('[Click here for more information.](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Generation-parameters.md)')
|
||||||
|
|
||||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
|
||||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
|
||||||
|
|
||||||
|
|
||||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||||
@ -696,17 +683,6 @@ def create_interface():
|
|||||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||||
|
|
||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
with gr.Box():
|
|
||||||
gr.Markdown("Chat parameters")
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
|
||||||
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='chat_prompt_size', info='Set limit on prompt size by removing old messages (while retaining context and user input)', value=shared.settings['chat_prompt_size'])
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations.')
|
|
||||||
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
|
||||||
|
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
# Create notebook mode interface
|
# Create notebook mode interface
|
||||||
|
@ -32,7 +32,7 @@ chat-instruct_command: 'Continue the chat dialogue below. Write a single reply f
|
|||||||
<|prompt|>'
|
<|prompt|>'
|
||||||
chat_prompt_size: 2048
|
chat_prompt_size: 2048
|
||||||
chat_prompt_size_min: 0
|
chat_prompt_size_min: 0
|
||||||
chat_prompt_size_max: 2048
|
chat_prompt_size_max: 8192
|
||||||
chat_generation_attempts: 1
|
chat_generation_attempts: 1
|
||||||
chat_generation_attempts_min: 1
|
chat_generation_attempts_min: 1
|
||||||
chat_generation_attempts_max: 10
|
chat_generation_attempts_max: 10
|
||||||
|
Loading…
x
Reference in New Issue
Block a user