Merge branch 'main' into main

SDS 2023-03-30 21:16:51 +02:00 committed by GitHub
commit 848f9d8fde
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 654 additions and 208 deletions


@ -13,7 +13,7 @@ jobs:
- uses: actions/stale@v5 - uses: actions/stale@v5
with: with:
stale-issue-message: "" stale-issue-message: ""
close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, you can reopen it (if you are the author) or leave a comment below." close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, please leave a comment below."
days-before-issue-stale: 30 days-before-issue-stale: 30
days-before-issue-close: 0 days-before-issue-close: 0
stale-issue-label: "stale" stale-issue-label: "stale"

.gitignore

@ -19,3 +19,4 @@ repositories
settings.json settings.json
img_bot* img_bot*
img_me* img_me*
prompts/[0-9]*


@ -36,10 +36,32 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
## Installation ## Installation
The recommended installation methods are the following: ### One-click installers
* Linux and MacOS: using conda natively. [oobabooga-windows.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-windows.zip)
* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)).
Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
* To download a model, double click on "download-model"
* To start the web UI, double click on "start-webui"
Source code: https://github.com/oobabooga/one-click-installers
> **Note**
>
> Thanks to [@jllllll](https://github.com/jllllll) and [@ClayShoaf](https://github.com/ClayShoaf), the Windows 1-click installer now sets up 8-bit and 4-bit requirements out of the box. No additional installation steps are necessary.
> **Note**
>
> There is no need to run the installer as admin.
### Manual installation using Conda
Recommended if you have some experience with the command line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide).
#### 0. Install Conda
Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html
@ -84,26 +106,10 @@ pip install -r requirements.txt
> >
> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859 > For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
### Alternative: one-click installers
[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip) ### Alternative: manual Windows installation
[oobabooga-linux.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-linux.zip) As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Windows installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-installation-guide).
Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
* To download a model, double click on "download-model"
* To start the web UI, double click on "start-webui"
Source code: https://github.com/oobabooga/one-click-installers
> **Note**
>
> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid).
### Alternative: native Windows installation
As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
### Alternative: Docker ### Alternative: Docker
@ -177,7 +183,7 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.| | `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. | | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. | | `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. | | `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. | | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
@ -198,12 +204,15 @@ Optionally, you can use the following command-line flags:
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. | | `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.| | `--model-dir MODEL_DIR` | Path to directory with all the models |
| `--listen-port LISTEN_PORT` | The listening port that the server will use. | | `--lora-dir LORA_DIR` | Path to directory with all the loras |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | | `--verbose` | Print the prompts to the terminal. |
| `--auto-launch` | Open the web UI in the default browser upon launch. | | `--listen` | Make the web UI reachable from your local network. |
| `--verbose` | Print the prompts to the terminal. | | `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
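
The `--gradio-auth-path` row above describes a file holding comma-separated `user:password` pairs. As a minimal sketch of how such a string could be turned into the list of tuples that Gradio's `launch(auth=...)` accepts, consider the following. The helper name `parse_auth_pairs` is hypothetical and is not taken from the repository; the error handling is an assumption as well.

```python
def parse_auth_pairs(text):
    """Parse "u1:p1,u2:p2" into [("u1", "p1"), ("u2", "p2")].

    Hypothetical helper for illustration. Each entry is split on the first
    colon only, so a password may itself contain colons.
    """
    pairs = []
    for item in text.strip().split(','):
        item = item.strip()
        if not item:
            continue  # tolerate trailing commas or blank entries
        user, sep, password = item.partition(':')
        if not sep or not user:
            raise ValueError(f"Malformed credential entry: {item!r}")
        pairs.append((user, password))
    return pairs
```

Splitting on the first colon only is a design choice of this sketch; the README does not say how colons inside passwords are treated.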


@ -23,3 +23,16 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
.pending.svelte-1ed2p3z { .pending.svelte-1ed2p3z {
opacity: 1; opacity: 1;
} }
#extensions {
padding: 0;
}
#gradio-chatbot {
height: 66.67vh;
}
.wrap.svelte-6roggh.svelte-6roggh {
max-height: 92.5%;
}


@ -37,20 +37,29 @@
text-decoration: none !important; text-decoration: none !important;
} }
svg {
display: unset !important;
vertical-align: middle !important;
margin: 5px;
}
ol li p, ul li p { ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #interface-mode, #lora { #main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
border: 0; border: 0;
} }
.gradio-container-3-18-0 .prose * h1, h2, h3, h4 { .gradio-container-3-18-0 .prose * h1, h2, h3, h4 {
color: white; color: white;
} }
.gradio-container {
max-width: 100% !important;
padding-top: 0 !important;
}
#extensions {
padding: 15px;
}
span.math.inline {
font-size: 27px;
vertical-align: baseline !important;
}


@ -11,7 +11,7 @@ let extensions = document.getElementById('extensions');
main_parent.addEventListener('click', function(e) { main_parent.addEventListener('click', function(e) {
// Check if the main element is visible // Check if the main element is visible
if (main.offsetHeight > 0 && main.offsetWidth > 0) { if (main.offsetHeight > 0 && main.offsetWidth > 0) {
extensions.style.display = 'block'; extensions.style.display = 'flex';
} else { } else {
extensions.style.display = 'none'; extensions.style.display = 'none';
} }


@ -8,38 +8,33 @@ python download-model.py facebook/opt-1.3b
import argparse import argparse
import base64 import base64
import datetime
import json import json
import multiprocessing
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
import requests import requests
import tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
args = parser.parse_args() args = parser.parse_args()
def get_file(args): def get_file(url, output_folder):
url = args[0]
output_folder = args[1]
idx = args[2]
tot = args[3]
print(f"Downloading file {idx} of {tot}...")
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0)) total_size = int(r.headers.get('content-length', 0))
block_size = 1024 block_size = 1024
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
t.close()
def sanitize_branch_name(branch_name): def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$") pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
@ -98,8 +93,10 @@ def get_download_links_from_huggingface(model, branch):
cursor = b"" cursor = b""
links = [] links = []
sha256 = []
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_pt = False
has_safetensors = False has_safetensors = False
is_lora = False is_lora = False
while True: while True:
@ -115,12 +112,14 @@ def get_download_links_from_huggingface(model, branch):
is_lora = True is_lora = True
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname) is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname) is_pt = re.match(".*\.pt", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)): if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']])
if is_text: if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text') classifications.append('text')
@ -134,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True has_pytorch = True
classifications.append('pytorch') classifications.append('pytorch')
elif is_pt: elif is_pt:
has_pt = True
classifications.append('pt') classifications.append('pt')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
@ -141,12 +141,15 @@ def get_download_links_from_huggingface(model, branch):
cursor = cursor.replace(b'=', b'%3D') cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only # If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors: if (has_pytorch or has_pt) and has_safetensors:
for i in range(len(classifications)-1, -1, -1): for i in range(len(classifications)-1, -1, -1):
if classifications[i] == 'pytorch': if classifications[i] in ['pytorch', 'pt']:
links.pop(i) links.pop(i)
return links, is_lora return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
if __name__ == '__main__': if __name__ == '__main__':
model = args.MODEL model = args.MODEL
@ -166,18 +169,32 @@ if __name__ == '__main__':
print(f"Error: {err_branch}") print(f"Error: {err_branch}")
sys.exit() sys.exit()
links, is_lora = get_download_links_from_huggingface(model, branch) links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
base_folder = 'models' if not is_lora else 'loras'
if branch != 'main': if args.output is not None:
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}') base_folder = args.output
else: else:
output_folder = Path(base_folder) / model.split('/')[-1] base_folder = 'models' if not is_lora else 'loras'
output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main':
output_folder += f'_{branch}'
# Creating the folder and writing the metadata
output_folder = Path(base_folder) / output_folder
if not output_folder.exists(): if not output_folder.exists():
output_folder.mkdir() output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files # Downloading the files
print(f"Downloading the model to {output_folder}") print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads) download_files(links, output_folder, args.threads)
results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))]) print()
pool.close()
pool.join()
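
The new code above records each LFS file's `oid` in `huggingface-metadata.txt` under a `sha256sum:` heading, one `<hash> <filename>` pair per line. A minimal sketch of how a downloaded folder could be checked against that file follows. The function names `read_expected_hashes` and `verify_folder` are hypothetical, and the sketch assumes the recorded `oid` is the SHA-256 hex digest of the file content, as Git LFS defines it.

```python
import hashlib
from pathlib import Path


def read_expected_hashes(metadata_path):
    """Return {filename: sha256_hex} from the lines after 'sha256sum:'."""
    expected = {}
    in_section = False
    for line in Path(metadata_path).read_text().splitlines():
        if line.strip() == 'sha256sum:':
            in_section = True
            continue
        if in_section and line.strip():
            # Each entry is "<hash> <filename>", possibly indented.
            digest, fname = line.strip().split(None, 1)
            expected[fname] = digest
    return expected


def verify_folder(folder):
    """Return the names of files whose SHA-256 does not match the metadata."""
    folder = Path(folder)
    expected = read_expected_hashes(folder / 'huggingface-metadata.txt')
    mismatched = []
    for fname, digest in expected.items():
        h = hashlib.sha256()
        with open(folder / fname, 'rb') as f:
            # Hash in 1 MiB chunks so large checkpoints are not read at once.
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                h.update(chunk)
        if h.hexdigest() != digest:
            mismatched.append(fname)
    return mismatched
```

In this sketch a file that is listed but missing raises `FileNotFoundError` instead of being reported as a mismatch.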


@ -43,14 +43,14 @@ class Handler(BaseHTTPRequestHandler):
generator = generate_reply( generator = generate_reply(
question = prompt, question = prompt,
max_new_tokens = body.get('max_length', 200), max_new_tokens = int(body.get('max_length', 200)),
do_sample=True, do_sample=True,
temperature=body.get('temperature', 0.5), temperature=float(body.get('temperature', 0.5)),
top_p=body.get('top_p', 1), top_p=float(body.get('top_p', 1)),
typical_p=body.get('typical', 1), typical_p=float(body.get('typical', 1)),
repetition_penalty=body.get('rep_pen', 1.1), repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1, encoder_repetition_penalty=1,
top_k=body.get('top_k', 0), top_k=int(body.get('top_k', 0)),
min_length=0, min_length=0,
no_repeat_ngram_size=0, no_repeat_ngram_size=0,
num_beams=1, num_beams=1,
@ -62,7 +62,10 @@ class Handler(BaseHTTPRequestHandler):
answer = '' answer = ''
for a in generator: for a in generator:
answer = a[0] if isinstance(a, str):
answer = a
else:
answer = a[0]
response = json.dumps({ response = json.dumps({
'results': [{ 'results': [{
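
The first hunk above wraps every value read from the request body in `int()` or `float()`, so numbers that arrive as JSON strings still reach the generator with the right type. The same idea can be sketched as a small table-driven helper. The names `PARAM_SPEC` and `coerce_params` are hypothetical; the keys, target parameter names and defaults are the ones shown in the diff.

```python
# Request key -> (generation parameter, cast, default), as in the diff above.
PARAM_SPEC = {
    'max_length': ('max_new_tokens', int, 200),
    'temperature': ('temperature', float, 0.5),
    'top_p': ('top_p', float, 1),
    'typical': ('typical_p', float, 1),
    'rep_pen': ('repetition_penalty', float, 1.1),
    'top_k': ('top_k', int, 0),
}


def coerce_params(body):
    """Build generation kwargs from a decoded JSON body, casting each value."""
    return {
        target: cast(body.get(key, default))
        for key, (target, cast, default) in PARAM_SPEC.items()
    }
```

A body such as `{'max_length': '50'}` then yields `max_new_tokens` as the integer 50, with every other parameter falling back to its default.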


@ -26,6 +26,7 @@ current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 
'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
# Used for making text xml compatible, needed for voice pitch and speed control # Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({ table = str.maketrans({
@ -77,6 +78,7 @@ def input_modifier(string):
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')] shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
shared.processing_message = "*Is recording a voice message...*" shared.processing_message = "*Is recording a voice message...*"
shared.args.no_stream = True # Disable streaming; otherwise the audio output stutters and restarts every time the message is updated
return string return string
def output_modifier(string): def output_modifier(string):
@ -84,7 +86,7 @@ def output_modifier(string):
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
global model, current_params global model, current_params, streaming_state
for i in params: for i in params:
if params[i] != current_params[i]: if params[i] != current_params[i]:
@ -116,6 +118,7 @@ def output_modifier(string):
string += f'\n\n{original_string}' string += f'\n\n{original_string}'
shared.processing_message = "*Is typing...*" shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
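
The hunks above save `shared.args.no_stream` once, force it to `True` in `input_modifier`, and restore it in `output_modifier`. The same save-and-restore pattern can be illustrated with a context manager. This is only a sketch of the pattern, not how the extension is structured (its two hooks run at different times); `temporary_attr` and the stand-in `args` object are hypothetical.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporary_attr(obj, name, value):
    """Set obj.<name> to value inside the block, then restore the old value."""
    previous = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, previous)  # restored even if the block raises


args = SimpleNamespace(no_stream=False)  # stand-in for shared.args
with temporary_attr(args, 'no_stream', True):
    inside = args.no_stream  # streaming is disabled here
after = args.no_stream  # the previous setting is back
```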


@ -4,22 +4,60 @@ from pathlib import Path
import accelerate import accelerate
import torch import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
import llama_inference_offload import llama_inference_offload
import opt from modelutils import find_layers
from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in exclude_layers:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name
if model_name.lower().startswith(('llama', 'alpaca')): name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama' model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')): elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt' model_type = 'opt'
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj'
else: else:
print("Can't determine model type from model name. Please specify it manually using --model_type " print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument") "argument")
@ -27,15 +65,12 @@ def load_quantized(model_name):
else: else:
model_type = shared.args.model_type.lower() model_type = shared.args.model_type.lower()
if model_type == 'llama': if model_type == 'llama' and shared.args.pre_layer:
if not shared.args.pre_layer: load_quant = llama_inference_offload.load_quant
load_quant = llama.load_quant elif model_type in ('llama', 'opt', 'gptj'):
else: load_quant = _load_quant
load_quant = llama_inference_offload.load_quant
elif model_type == 'opt':
load_quant = opt.load_quant
else: else:
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit() exit()
# Now we are going to try to locate the quantized model file. # Now we are going to try to locate the quantized model file.
@ -75,7 +110,8 @@ def load_quantized(model_name):
if shared.args.pre_layer: if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else: else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize) threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory:
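
The detection logic in `load_quantized` changes from prefix matching to substring matching and gains a GPT-J branch. Pulled out into a standalone function, it can be sketched as below. The function name `infer_model_type` is hypothetical, and returning `None` is an assumption made here for illustration; the keywords and their order are the ones in the diff.

```python
def infer_model_type(model_name):
    """Guess the GPTQ model type from substrings of the model name.

    Returns 'llama', 'opt' or 'gptj', or None when nothing matches.
    """
    name = model_name.lower()
    if any(k in name for k in ('llama', 'alpaca')):
        return 'llama'
    if any(k in name for k in ('opt-', 'galactica')):
        return 'opt'
    if any(k in name for k in ('gpt-j', 'pygmalion-6b')):
        return 'gptj'
    return None
```

Substring matching matters because the updated download script names folders by joining the last two path components with an underscore, so a folder name can begin with the organization rather than with `llama` or `opt`.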


@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import torch import torch
from peft import PeftModel
import modules.shared as shared import modules.shared as shared
from modules.models import load_model from modules.models import load_model
@ -14,15 +15,13 @@ def reload_model():
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):
from peft import PeftModel
# If a LoRA had been previously loaded, or if we want # If a LoRA had been previously loaded, or if we want
# to unload a LoRA, reload the model # to unload a LoRA, reload the model
if shared.lora_name != "None" or lora_name == "None": if shared.lora_name not in ['None', ''] or lora_name in ['None', '']:
reload_model() reload_model()
shared.lora_name = lora_name shared.lora_name = lora_name
if lora_name != "None": if lora_name not in ['None', '']:
print(f"Adding the LoRA {lora_name} to the model...") print(f"Adding the LoRA {lora_name} to the model...")
params = {} params = {}
if not shared.args.cpu: if not shared.args.cpu:
@ -32,7 +31,7 @@ def add_lora_to_model(lora_name):
elif shared.args.load_in_8bit: elif shared.args.load_in_8bit:
params['device_map'] = {'': 0} params['device_map'] = {'': 0}
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
if not shared.args.load_in_8bit and not shared.args.cpu: if not shared.args.load_in_8bit and not shared.args.cpu:
shared.model.half() shared.model.half()
if not hasattr(shared.model, "hf_device_map"): if not hasattr(shared.model, "hf_device_map"):


@ -1,4 +1,5 @@
import gc import gc
import traceback
from queue import Queue from queue import Queue
from threading import Thread from threading import Thread
@ -54,7 +55,7 @@ class Iteratorize:
self.stop_now = False self.stop_now = False
def _callback(val): def _callback(val):
if self.stop_now: if self.stop_now or shared.stop_everything:
raise ValueError raise ValueError
self.q.put(val) self.q.put(val)
@ -63,6 +64,10 @@ class Iteratorize:
ret = self.mfunc(callback=_callback, **self.kwargs) ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError: except ValueError:
pass pass
except:
traceback.print_exc()
pass
clear_torch_cache() clear_torch_cache()
self.q.put(self.sentinel) self.q.put(self.sentinel)
if self.c_callback: if self.c_callback:
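
The `Iteratorize` class touched above turns a callback-driven function into an iterator by running it in a thread and passing values through a queue. A stripped-down sketch of that pattern follows, including the early-stop check and the traceback printing added in this commit. The class name `CallbackIterator` and the demo function `count_to_three` are hypothetical, and details such as the daemon thread are assumptions of this sketch.

```python
import traceback
from queue import Queue
from threading import Thread


class CallbackIterator:
    """Expose a function that reports results via callback as an iterator."""

    def __init__(self, func):
        self.q = Queue()
        self.sentinel = object()
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError  # unwinds the worker once a stop is requested
            self.q.put(val)

        def _worker():
            try:
                func(callback=_callback)
            except ValueError:
                pass  # expected when stopping early
            except Exception:
                traceback.print_exc()  # surface unexpected errors
            finally:
                self.q.put(self.sentinel)  # always unblock the consumer

        Thread(target=_worker, daemon=True).start()

    def __iter__(self):
        return self

    def __next__(self):
        item = self.q.get()
        if item is self.sentinel:
            raise StopIteration
        return item


def count_to_three(callback):
    """Demo producer that reports three values through its callback."""
    for i in (1, 2, 3):
        callback(i)


values = list(CallbackIterator(count_to_three))
```

Putting the sentinel in a `finally` clause means the consumer is released whether the producer finishes, is stopped, or fails.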


@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
reply = fix_newlines(reply) reply = fix_newlines(reply)
return reply, next_character_found return reply, next_character_found
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True just_started = True
eos_token = '\n' if check else None eos_token = '\n' if check else None
name1_original = name1 name1_original = name1


@ -7,7 +7,7 @@ import modules.shared as shared
state = {} state = {}
available_extensions = [] available_extensions = []
setup_called = False setup_called = set()
def load_extensions(): def load_extensions():
global state global state
@ -53,18 +53,17 @@ def create_extensions_block():
should_display_ui = False should_display_ui = False
# Running setup function # Running setup function
if not setup_called: for extension, name in iterator():
for extension, name in iterator(): if hasattr(extension, "ui"):
if hasattr(extension, "setup"): should_display_ui = True
extension.setup() if extension not in setup_called and hasattr(extension, "setup"):
if hasattr(extension, "ui"): setup_called.add(extension)
should_display_ui = True extension.setup()
setup_called = True
# Creating the extension ui elements # Creating the extension ui elements
if should_display_ui: if should_display_ui:
with gr.Box(elem_id="extensions"): with gr.Column(elem_id="extensions"):
gr.Markdown("Extensions")
for extension, name in iterator(): for extension, name in iterator():
gr.Markdown(f"\n### {name}")
if hasattr(extension, "ui"): if hasattr(extension, "ui"):
extension.ui() extension.ui()
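
The change above replaces the single `setup_called` boolean with a set, so each extension's `setup()` runs at most once even when the block is rebuilt or new extensions appear later. A minimal sketch of that bookkeeping follows; `run_setup_once` and `DummyExtension` are hypothetical names used only for illustration.

```python
setup_called = set()


def run_setup_once(extensions):
    """Call setup() on each extension the first time it is seen."""
    for extension in extensions:
        if extension not in setup_called and hasattr(extension, 'setup'):
            setup_called.add(extension)
            extension.setup()


class DummyExtension:
    """Stand-in for an extension module; counts how often setup() runs."""

    def __init__(self):
        self.calls = 0

    def setup(self):
        self.calls += 1


ext = DummyExtension()
run_setup_once([ext])
run_setup_once([ext])  # second pass is a no-op for ext
```

With a single boolean, an extension added after the first pass would never have had its `setup()` called; tracking per extension avoids that.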


@ -34,7 +34,7 @@ def convert_to_markdown(string):
string = string.replace('\\begin{blockquote}', '> ') string = string.replace('\\begin{blockquote}', '> ')
string = string.replace('\\end{blockquote}', '') string = string.replace('\\end{blockquote}', '')
string = re.sub(r"(.)```", r"\1\n```", string) string = re.sub(r"(.)```", r"\1\n```", string)
# string = fix_newlines(string) string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code']) return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string): def generate_basic_html(string):


@ -41,14 +41,14 @@ def load_model(model_name):
print(f"Loading {model_name}...") print(f"Loading {model_name}...")
t0 = time.time() t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-') shared.is_RWKV = 'rwkv-' in model_name.lower()
# Default settings # Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else: else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps: if torch.has_mps:
device = torch.device('mps') device = torch.device('mps')
model = model.to(device) model = model.to(device)
@ -76,11 +76,11 @@ def load_model(model_name):
num_bits=4, group_size=64, num_bits=4, group_size=64,
group_dim=2, symmetric=False)) group_dim=2, symmetric=False))
model = OptLM(f"facebook/{shared.model_name}", env, "models", policy) model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
# DeepSpeed ZeRO-3 # DeepSpeed ZeRO-3
elif shared.args.deepspeed: elif shared.args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model.module.eval() # Inference model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@ -89,8 +89,8 @@ def load_model(model_name):
elif shared.is_RWKV: elif shared.is_RWKV:
from modules.RWKV import RWKVModel, RWKVTokenizer from modules.RWKV import RWKVModel, RWKVTokenizer
model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
tokenizer = RWKVTokenizer.from_pretrained(Path('models')) tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
return model, tokenizer return model, tokenizer
@ -142,7 +142,7 @@ def load_model(model_name):
if shared.args.disk: if shared.args.disk:
params["offload_folder"] = shared.args.disk_cache_dir params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'models/{shared.model_name}') checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint) config = AutoConfig.from_pretrained(checkpoint)
@ -159,10 +159,10 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Loading the tokenizer # Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left' tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
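A note on the path change in the hunks above: the new default for `--model-dir` is `models/` (with a trailing slash), so `f"{shared.args.model_dir}/{shared.model_name}"` produces a string with a doubled slash. The sketch below suggests why that is harmless once the string is wrapped in `Path`. The helper name `model_path` is hypothetical and not part of the repository.

```python
from pathlib import Path


def model_path(model_dir: str, model_name: str) -> Path:
    """Hypothetical helper mirroring the f-string pattern used in load_model()."""
    # With model_dir == 'models/' the raw string is 'models//gpt-j-6B';
    # pathlib collapses the repeated separator when it parses the string.
    return Path(f"{model_dir}/{model_name}")


print(model_path("models/", "gpt-j-6B"))
print(model_path("models", "gpt-j-6B"))
```

Both calls should resolve to the same path, so users can pass the directory with or without the trailing slash.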


@ -37,28 +37,23 @@ settings = {
'chat_generation_attempts': 1, 'chat_generation_attempts': 1,
'chat_generation_attempts_min': 1, 'chat_generation_attempts_min': 1,
'chat_generation_attempts_max': 5, 'chat_generation_attempts_max': 5,
'name1_pygmalion': 'You',
'name2_pygmalion': 'Kawaii',
'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
'stop_at_newline_pygmalion': False,
'default_extensions': [], 'default_extensions': [],
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {
'default': 'NovelAI-Sphinx Moth', 'default': 'NovelAI-Sphinx Moth',
'(alpaca-*|llama-*)': "LLaMA-Precise", '(alpaca-*|llama-*)': "LLaMA-Precise",
'pygmalion-*': 'Pygmalion', '.*pygmalion': 'Pygmalion',
'RWKV-*': 'Naive', '.*RWKV': 'Naive',
}, },
'prompts': { 'prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', 'default': 'QA',
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', '.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n', '.*oasst': 'Open Assistant',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>', '.*alpaca': "Alpaca",
'alpaca-*': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n",
}, },
'lora_prompts': { 'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', 'default': 'QA',
'(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" '.*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Alpaca",
} }
} }
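The keys of `presets` and `prompts` are regular expressions, and this hunk rewrites several of them from a prefix form (`pygmalion-*`) to a `.*` form (`.*pygmalion`). The sketch below illustrates the likely effect, assuming the keys are resolved with the same `next(... re.match ...)` lookup that appears in the removed `load_lora_wrapper` line in server.py. The function `pick_preset` and the model folder name `someuser_pygmalion-6b` are hypothetical.

```python
import re

# Preset table as it reads after this commit.
presets = {
    'default': 'NovelAI-Sphinx Moth',
    '(alpaca-*|llama-*)': 'LLaMA-Precise',
    '.*pygmalion': 'Pygmalion',
    '.*RWKV': 'Naive',
}


def pick_preset(model_name: str) -> str:
    """Hypothetical lookup: first key whose regex matches the model name, else 'default'."""
    key = next((k for k in presets if re.match(k.lower(), model_name.lower())), 'default')
    return presets[key]


# re.match anchors at the start of the string, so the old key only matched
# names that begin with "pygmalion"; the leading ".*" lifts that restriction.
print(re.match('pygmalion-*', 'someuser_pygmalion-6b'))
print(pick_preset('someuser_pygmalion-6b'))
print(pick_preset('gpt-j-6B'))
```

Note that the `(alpaca-*|llama-*)` key keeps the prefix form in this commit, so under this lookup it would still only match names that start with `alpaca` or `llama`.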
@ -85,7 +80,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.') parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.') parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.') parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.') parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.') parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
@ -108,11 +103,14 @@ parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile t
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
args = parser.parse_args() args = parser.parse_args()
# Provisional, this will be deleted later # Provisional, this will be deleted later
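The new flags are declared with dashes (`--model-dir`, `--lora-dir`) but are read elsewhere in this commit as `shared.args.model_dir` and `shared.args.lora_dir`. A minimal sketch of that mapping, using only the two new arguments; the directory `/data/llm-models` is a made-up value for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")

# argparse derives the attribute name from the long option and replaces
# dashes with underscores, so "--model-dir" is exposed as args.model_dir.
# The argument list is passed explicitly here instead of reading sys.argv.
args = parser.parse_args(["--model-dir", "/data/llm-models"])

print(args.model_dir)  # overridden on the command line
print(args.lora_dir)   # falls back to the default
```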


@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
def decode(output_ids): def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|> # Open Assistant relies on special tokens like <|endoftext|>
if re.match('(oasst|galactica)-*', shared.model_name.lower()): if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
return shared.tokenizer.decode(output_ids, skip_special_tokens=False) return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else: else:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@ -77,10 +77,10 @@ def fix_galactica(s):
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not (shared.args.chat or shared.args.cai_chat): if not (shared.args.chat or shared.args.cai_chat):
if model_name.lower().startswith('galactica'): if 'galactica' in model_name.lower():
reply = fix_galactica(reply) reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply) return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
reply = fix_gpt4chan(reply) reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else: else:
@ -99,9 +99,13 @@ def set_manual_seed(seed):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]): def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(seed) set_manual_seed(seed)
shared.stop_everything = False
t0 = time.time() t0 = time.time()
original_question = question original_question = question
@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
break break
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(max_new_tokens//8+1): for i in range(max_new_tokens//8+1):

275
modules/training.py Normal file

@ -0,0 +1,275 @@
import json
import sys
import threading
import time
import traceback
from pathlib import Path
import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
prepare_model_for_int8_training)
from modules import shared, ui
WANT_INTERRUPT = False
CURRENT_STEPS = 0
MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str):
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower)
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
with gr.Row():
# TODO: Implement multi-device support.
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
with gr.Row():
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
# TODO: Better explain what this does, in terms of real world effect especially.
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Tab(label="Formatted Dataset"):
with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
with gr.Row():
start_button = gr.Button("Start LoRA Training")
stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready")
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS
CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS
CURRENT_STEPS += 1
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def clean_path(base_path: str, path: str):
"""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
# Or swap it to a strict whitelist of [a-zA-Z_0-9]
path = path.replace('\\', '/').replace('..', '_')
if base_path is None:
return path
return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
WANT_INTERRUPT = False
CURRENT_STEPS = 0
MAX_STEPS = 0
# == Input validation / processing ==
yield "Prepping..."
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
actual_lr = float(learning_rate)
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
return
gradient_accumulation_steps = batch_size // micro_batch_size
CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps
shared.tokenizer.pad_token = 0
shared.tokenizer.padding_side = "left"
def tokenize(prompt):
result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
return {
"input_ids": result["input_ids"][:-1],
"attention_mask": result["attention_mask"][:-1],
}
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens
data = Dataset.from_list([tokenize(x) for x in text_chunks])
train_data = data.shuffle()
eval_data = None
del text_chunks
else:
if dataset in ['None', '']:
yield "**Missing dataset choice input, cannot continue.**"
return
if format in ['None', '']:
yield "**Missing format choice input, cannot continue.**"
return
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
for key, val in data_point.items():
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt)
print("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
if eval_dataset == 'None':
eval_data = None
else:
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
prepare_model_for_int8_training(shared.model)
print("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
# TODO: Should target_modules be configurable?
target_modules=[ "q_proj", "v_proj" ],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
try:
lora_model = get_peft_model(shared.model, config)
except Exception:
yield traceback.format_exc()
return
trainer = transformers.Trainer(
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# TODO: Should more of these be configurable? Probably.
warmup_steps=100,
num_train_epochs=epochs,
learning_rate=actual_lr,
fp16=True,
logging_steps=20,
evaluation_strategy="steps" if eval_data is not None else "no",
save_strategy="steps",
eval_steps=200 if eval_data is not None else None,
save_steps=200,
output_dir=lora_name,
save_total_limit=3,
load_best_model_at_end=True if eval_data is not None else False,
# TODO: Enable multi-device support
ddp_find_unused_parameters=None
),
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
)
lora_model.config.use_cache = False
old_state_dict = lora_model.state_dict
lora_model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(lora_model, type(lora_model))
if int(torch.__version__.split(".")[0]) >= 2 and sys.platform != "win32":
lora_model = torch.compile(lora_model)
# == Main run and monitor loop ==
# TODO: save/load checkpoints to resume from?
print("Starting training...")
yield "Starting..."
def threadedRun():
trainer.train()
thread = threading.Thread(target=threadedRun)
thread.start()
lastStep = 0
startTime = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
elif CURRENT_STEPS != lastStep:
lastStep = CURRENT_STEPS
timeElapsed = time.perf_counter() - startTime
if timeElapsed <= 0:
timerInfo = ""
totalTimeEstimate = 999
else:
its = CURRENT_STEPS / timeElapsed
if its > 1:
timerInfo = f"`{its:.2f}` it/s"
else:
timerInfo = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
print("Training complete, saving...")
lora_model.save_pretrained(lora_name)
if WANT_INTERRUPT:
print("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`"
else:
print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step):
for i in range(0, len(arr), step):
yield arr[i:i + step]
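The raw-text branch of `do_train` combines `split_chunks` with an overlap step: it splits the token list with a stride of `cutoff_len - overlap_len`, then prepends the tail of each previous chunk to the next one. The sketch below isolates that logic on plain integers. The wrapper `chunk_with_overlap` is hypothetical, and its input check and `overlap_len > 0` guard are additions of this sketch that are not in the commit. The guard is there because `lst[-0:]` returns the whole list, so an overlap of 0 would otherwise prepend entire chunks.

```python
def split_chunks(arr, step):
    # Same generator as in modules/training.py.
    for i in range(0, len(arr), step):
        yield arr[i:i + step]


def chunk_with_overlap(tokens, cutoff_len, overlap_len):
    """Hypothetical wrapper around the chunking done in do_train()."""
    if not 0 <= overlap_len < cutoff_len:
        # Added check: a stride of zero or less cannot produce chunks.
        raise ValueError("overlap_len must be >= 0 and smaller than cutoff_len")
    chunks = list(split_chunks(tokens, cutoff_len - overlap_len))
    if overlap_len > 0:  # added guard, see the note above about [-0:]
        for i in range(1, len(chunks)):
            chunks[i] = chunks[i - 1][-overlap_len:] + chunks[i]
    return chunks


print(chunk_with_overlap(list(range(10)), cutoff_len=4, overlap_len=1))
# [[0, 1, 2], [2, 3, 4, 5], [5, 6, 7, 8], [8, 9]]
```

Every chunk after the first starts with the last `overlap_len` tokens of its predecessor, and no chunk exceeds `cutoff_len` tokens.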

6
prompts/Alpaca.txt Normal file

@ -0,0 +1,6 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:

6
prompts/GPT-4chan.txt Normal file

@ -0,0 +1,6 @@
-----
--- 865467536
Hello, AI frens!
How are you doing on this fine day?
--- 865467537


@ -0,0 +1 @@
<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>

4
prompts/QA.txt Normal file
View File

@ -0,0 +1,4 @@
Common sense questions and answers
Question:
Factual answer:


@ -1,13 +1,14 @@
accelerate==0.17.1 accelerate==0.18.0
bitsandbytes==0.37.1 bitsandbytes==0.37.2
flexgen==0.1.7 flexgen==0.1.7
gradio==3.18.0 gradio==3.23.0
markdown markdown
numpy numpy
peft==0.2.0 peft==0.2.0
requests requests
rwkv==0.7.0 rwkv==0.7.1
safetensors==0.3.0 safetensors==0.3.0
sentencepiece sentencepiece
tqdm tqdm
datasets
git+https://github.com/huggingface/transformers git+https://github.com/huggingface/transformers

169
server.py

@ -4,18 +4,18 @@ import re
import sys import sys
import time import time
import zipfile import zipfile
from datetime import datetime
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.extensions as extensions_module import modules.extensions as extensions_module
import modules.shared as shared from modules import chat, shared, training, ui
import modules.ui as ui
from modules.html_generator import generate_chat_html from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt
from modules.text_generation import clear_torch_cache, generate_reply from modules.text_generation import (clear_torch_cache, generate_reply,
stop_everything_event)
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
@ -31,13 +31,20 @@ if settings_file is not None:
def get_available_models(): def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_prompts():
prompts = []
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
prompts += ['None']
return prompts
def get_available_characters(): def get_available_characters():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
@ -48,22 +55,25 @@ def get_available_softprompts():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras(): def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model
shared.model = shared.tokenizer = None
clear_torch_cache() unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name) if selected_model != '':
shared.model, shared.tokenizer = load_model(shared.model_name)
return selected_model return selected_model
def load_lora_wrapper(selected_lora): def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora) add_lora_to_model(selected_lora)
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] return selected_lora
return selected_lora, default_text
def load_preset_values(preset_menu, return_dict=False): def load_preset_values(preset_menu, return_dict=False):
generate_params = { generate_params = {
@ -93,7 +103,7 @@ def load_preset_values(preset_menu, return_dict=False):
if return_dict: if return_dict:
return generate_params return generate_params
else: else:
return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
def upload_soft_prompt(file): def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf: with zipfile.ZipFile(io.BytesIO(file)) as zf:
@ -118,9 +128,46 @@ def create_model_and_preset_menus():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text)
return f"Saved to prompts/{fname}"
def load_prompt(fname):
if fname in ['None', '']:
return ''
else:
with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
text = f.read()
if text[-1] == '\n':
text = text[:-1]
return text
def create_prompt_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
with gr.Column():
with gr.Column():
shared.gradio['save_prompt'] = gr.Button('Save prompt')
shared.gradio['status'] = gr.Markdown('Ready')
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
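The `save_prompt` / `load_prompt` pair added in this hunk writes the textbox contents to a timestamped file and reads it back with one trailing newline dropped. Below is a minimal sketch of that round trip, not the code from this commit. It makes two changes that are assumptions of the sketch: the directory is passed in as a hypothetical `prompt_dir` parameter instead of being hard-coded to `prompts/`, and the timestamp uses `-` rather than `:` because colons are not allowed in Windows file names. It also uses `str.endswith`, which is safe on an empty file where `text[-1]` would raise `IndexError`.

```python
from datetime import datetime
from pathlib import Path


def save_prompt(text, prompt_dir):
    # Hypothetical variant: the directory is a parameter, and the timestamp
    # uses '-' instead of ':' so the file name is also valid on Windows.
    prompt_dir = Path(prompt_dir)
    prompt_dir.mkdir(parents=True, exist_ok=True)
    fname = f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.txt"
    (prompt_dir / fname).write_text(text, encoding='utf-8')
    return fname


def load_prompt(fname, prompt_dir):
    # 'None' and '' both mean "no prompt selected", as in the diff above.
    if fname in ['None', '']:
        return ''
    text = (Path(prompt_dir) / f'{fname}.txt').read_text(encoding='utf-8')
    # endswith() works on an empty string, unlike indexing with text[-1].
    if text.endswith('\n'):
        text = text[:-1]
    return text
```

As in the diff, `load_prompt` takes the name without the `.txt` extension, which is what the dropdown entries are expected to contain.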
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
with gr.Row():
with gr.Column():
create_model_and_preset_menus()
with gr.Column():
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Box(): with gr.Box():
@ -151,12 +198,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
with gr.Row():
shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Row(): with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
@ -171,9 +212,8 @@ def create_settings_menus(default_preset):
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
@ -238,12 +278,10 @@ if shared.args.lora:
# Default UI settings # Default UI settings
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
if shared.lora_name != "None": if shared.lora_name != "None":
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
else: else:
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title ='Text generation web UI' title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
def create_interface(): def create_interface():
@ -255,13 +293,13 @@ def create_interface():
if shared.args.chat or shared.args.cai_chat: if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat: if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
else: else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
shared.gradio['Generate'] = gr.Button('Generate') shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
with gr.Row(): with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate') shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate') shared.gradio['Regenerate'] = gr.Button('Regenerate')
@ -274,12 +312,10 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
create_model_and_preset_menus()
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@ -317,7 +353,7 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
@ -328,7 +364,7 @@ def create_interface():
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@ -364,24 +400,34 @@ def create_interface():
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Row(): with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop') with gr.Column(scale=4):
shared.gradio['Generate'] = gr.Button('Generate') with gr.Tab('Raw'):
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop')
with gr.Column():
pass
with gr.Column(scale=1):
gr.HTML('<div style="padding-bottom: 13px"></div>')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_prompt_menus()
create_model_and_preset_menus()
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
@ -389,7 +435,7 @@ def create_interface():
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else: else:
@ -405,7 +451,7 @@ def create_interface():
with gr.Column(): with gr.Column():
shared.gradio['Stop'] = gr.Button('Stop') shared.gradio['Stop'] = gr.Button('Stop')
create_model_and_preset_menus() create_prompt_menus()
with gr.Column(): with gr.Column():
with gr.Tab('Raw'): with gr.Tab('Raw'):
@ -414,6 +460,7 @@ def create_interface():
shared.gradio['markdown'] = gr.Markdown() shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'): with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML() shared.gradio['html'] = gr.HTML()
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
@ -422,9 +469,12 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface()
with gr.Tab("Interface mode", elem_id="interface-mode"): with gr.Tab("Interface mode", elem_id="interface-mode"):
modes = ["default", "notebook", "chat", "cai_chat"] modes = ["default", "notebook", "chat", "cai_chat"]
current_mode = "default" current_mode = "default"
@ -443,17 +493,26 @@ def create_interface():
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary") shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None) shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}') shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
# Authentication
auth = None
if shared.args.gradio_auth_path is not None:
gradio_auth_creds = []
with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
# Launch the interface # Launch the interface
shared.gradio['interface'].queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface() create_interface()
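The authentication block near the end of this file reads `user:password` entries, separated by commas or newlines, and passes the resulting list of tuples to `launch(auth=...)`. The parsing step can be sketched in isolation as follows. The function name `parse_gradio_auth` is hypothetical, and splitting on the first `:` only is an assumption of this sketch (the diff uses a plain `split(':')`), made so that a password containing a colon still yields a two-element tuple.

```python
def parse_gradio_auth(text):
    # Collect every non-empty, comma-separated "user:password" entry.
    creds = []
    for line in text.splitlines():
        creds += [x.strip() for x in line.split(',') if x.strip()]
    # Assumption of this sketch: split on the first ':' only, so a password
    # that itself contains ':' still gives a (user, password) pair.
    return [tuple(cred.split(':', 1)) for cred in creds]


sample = "alice:secret1, bob:secret2\ncarol:pa:ss\n"
print(parse_gradio_auth(sample))
# [('alice', 'secret1'), ('bob', 'secret2'), ('carol', 'pa:ss')]
```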

View File

@ -12,27 +12,23 @@
"chat_generation_attempts": 1, "chat_generation_attempts": 1,
"chat_generation_attempts_min": 1, "chat_generation_attempts_min": 1,
"chat_generation_attempts_max": 5, "chat_generation_attempts_max": 5,
"name1_pygmalion": "You",
"name2_pygmalion": "Kawaii",
"context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
"stop_at_newline_pygmalion": false,
"default_extensions": [], "default_extensions": [],
"chat_default_extensions": [ "chat_default_extensions": [
"gallery" "gallery"
], ],
"presets": { "presets": {
"default": "NovelAI-Sphinx Moth", "default": "NovelAI-Sphinx Moth",
"pygmalion-*": "Pygmalion", ".*pygmalion": "Pygmalion",
"RWKV-*": "Naive" ".*RWKV": "Naive"
}, },
"prompts": { "prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "default": "QA",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", ".*(gpt4chan|gpt-4chan|4chan)": "GPT-4chan",
"(rosey|chip|joi)_.*_instruct.*": "User: \n", ".*oasst": "Open Assistant",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>" ".*alpaca": "Alpaca"
}, },
"lora_prompts": { "lora_prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "default": "QA",
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" ".*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Alpaca"
} }
} }
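The keys in `presets`, `prompts`, and `lora_prompts` are regular expressions that the server tests with `re.match` against the lowercased model or LoRA name, falling back to `default`. Because `re.match` only matches at the start of the string, the new `.*` prefix lets a key match anywhere in the name. The sketch below reproduces the lookup expression from the diff; the helper name `pick` and the sample model names are illustrative only.

```python
import re

presets = {
    "default": "NovelAI-Sphinx Moth",
    ".*pygmalion": "Pygmalion",
    ".*RWKV": "Naive",
}


def pick(table, name):
    # First key whose lowercased pattern matches at the start of the
    # lowercased name; otherwise fall back to the 'default' entry.
    key = next((k for k in table if re.match(k.lower(), name.lower())), 'default')
    return table[key]


print(pick(presets, "my-pygmalion-6b"))  # Pygmalion
print(pick(presets, "llama-7b"))         # NovelAI-Sphinx Moth

# The old key only matched names that begin with "pygmalion":
print(re.match("pygmalion-*", "my-pygmalion-6b"))  # None
```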

View File

@ -0,0 +1,4 @@
{
"instruction,output": "User: %instruction%\nAssistant: %output%",
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}

View File

@ -0,0 +1,4 @@
{
"instruction,output": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Response:\n%output%",
"instruction,input,output": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Input:\n%input%\n\n### Response:\n%output%"
}
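The two format files above map a comma-separated list of field names to a template with `%field%` placeholders. The code that consumes them is not part of this section, so the following is only a sketch of how such a file could be applied. The function `format_example` is hypothetical, and so is the rule it uses to choose a template (join the names of the non-empty fields in record order).

```python
def format_example(template_map, data_point):
    # Assumption: the key is the comma-joined list of fields that are
    # present and non-empty, in the order they appear in the record.
    key = ','.join(k for k, v in data_point.items() if v)
    template = template_map[key]  # raises KeyError if no template fits
    for field, value in data_point.items():
        template = template.replace(f'%{field}%', value)
    return template


templates = {
    "instruction,output": "User: %instruction%\nAssistant: %output%",
    "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%",
}

print(format_example(templates, {"instruction": "Add", "input": "1 2", "output": "3"}))
```

A record with an empty `input` selects the two-field template under this rule, so the same dataset can mix both shapes.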