Merge branch 'main' into main

This commit is contained in:
SDS 2023-03-30 21:16:51 +02:00 committed by GitHub
commit 848f9d8fde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 654 additions and 208 deletions

View File

@ -13,7 +13,7 @@ jobs:
- uses: actions/stale@v5
with:
stale-issue-message: ""
close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, you can reopen it (if you are the author) or leave a comment below."
close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, please leave a comment below."
days-before-issue-stale: 30
days-before-issue-close: 0
stale-issue-label: "stale"

1
.gitignore vendored
View File

@ -19,3 +19,4 @@ repositories
settings.json
img_bot*
img_me*
prompts/[0-9]*

View File

@ -36,10 +36,32 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
## Installation
The recommended installation methods are the following:
### One-click installers
* Linux and MacOS: using conda natively.
* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)).
[oobabooga-windows.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-windows.zip)
Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
* To download a model, double click on "download-model"
* To start the web UI, double click on "start-webui"
Source codes: https://github.com/oobabooga/one-click-installers
> **Note**
>
> Thanks to [@jllllll](https://github.com/jllllll) and [@ClayShoaf](https://github.com/ClayShoaf), the Windows 1-click installer now sets up 8-bit and 4-bit requirements out of the box. No additional installation steps are necessary.
> **Note**
>
> There is no need to run the installer as admin.
### Manual installation using Conda
Recommended if you have some experience with the command-line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide).
#### 0. Install Conda
Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html
@ -84,26 +106,10 @@ pip install -r requirements.txt
>
> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
### Alternative: one-click installers
[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
### Alternative: manual Windows installation
[oobabooga-linux.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-linux.zip)
Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder.
* To download a model, double click on "download-model"
* To start the web UI, double click on "start-webui"
Source codes: https://github.com/oobabooga/one-click-installers
> **Note**
>
> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid).
### Alternative: native Windows installation
As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Windows installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-installation-guide).
### Alternative: Docker
@ -177,7 +183,7 @@ Optionally, you can use the following command-line flags:
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
@ -199,11 +205,14 @@ Optionally, you can use the following command-line flags:
| `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.|
| `--model-dir MODEL_DIR` | Path to directory with all the models |
| `--lora-dir LORA_DIR` | Path to directory with all the loras |
| `--verbose` | Print the prompts to the terminal. |
| `--listen` | Make the web UI reachable from your local network. |
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--verbose` | Print the prompts to the terminal. |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).

View File

@ -23,3 +23,16 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
.pending.svelte-1ed2p3z {
opacity: 1;
}
#extensions {
padding: 0;
padding: 0;
}
#gradio-chatbot {
height: 66.67vh;
}
.wrap.svelte-6roggh.svelte-6roggh {
max-height: 92.5%;
}

View File

@ -37,20 +37,29 @@
text-decoration: none !important;
}
svg {
display: unset !important;
vertical-align: middle !important;
margin: 5px;
}
ol li p, ul li p {
display: inline-block;
}
#main, #parameters, #chat-settings, #interface-mode, #lora {
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
border: 0;
}
.gradio-container-3-18-0 .prose * h1, h2, h3, h4 {
color: white;
}
.gradio-container {
max-width: 100% !important;
padding-top: 0 !important;
}
#extensions {
padding: 15px;
padding: 15px;
}
span.math.inline {
font-size: 27px;
vertical-align: baseline !important;
}

View File

@ -11,7 +11,7 @@ let extensions = document.getElementById('extensions');
main_parent.addEventListener('click', function(e) {
// Check if the main element is visible
if (main.offsetHeight > 0 && main.offsetWidth > 0) {
extensions.style.display = 'block';
extensions.style.display = 'flex';
} else {
extensions.style.display = 'none';
}

View File

@ -8,38 +8,33 @@ python download-model.py facebook/opt-1.3b
import argparse
import base64
import datetime
import json
import multiprocessing
import re
import sys
from pathlib import Path
import requests
import tqdm
from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
args = parser.parse_args()
def get_file(args):
url = args[0]
output_folder = args[1]
idx = args[2]
tot = args[3]
print(f"Downloading file {idx} of {tot}...")
def get_file(url, output_folder):
r = requests.get(url, stream=True)
with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
t.close()
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
@ -98,8 +93,10 @@ def get_download_links_from_huggingface(model, branch):
cursor = b""
links = []
sha256 = []
classifications = []
has_pytorch = False
has_pt = False
has_safetensors = False
is_lora = False
while True:
@ -115,12 +112,14 @@ def get_download_links_from_huggingface(model, branch):
is_lora = True
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']])
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
@ -134,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')
elif is_pt:
has_pt = True
classifications.append('pt')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
@ -141,12 +141,15 @@ def get_download_links_from_huggingface(model, branch):
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors:
if (has_pytorch or has_pt) and has_safetensors:
for i in range(len(classifications)-1, -1, -1):
if classifications[i] == 'pytorch':
if classifications[i] in ['pytorch', 'pt']:
links.pop(i)
return links, is_lora
return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
if __name__ == '__main__':
model = args.MODEL
@ -166,18 +169,32 @@ if __name__ == '__main__':
print(f"Error: {err_branch}")
sys.exit()
links, is_lora = get_download_links_from_huggingface(model, branch)
base_folder = 'models' if not is_lora else 'loras'
if branch != 'main':
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
if args.output is not None:
base_folder = args.output
else:
output_folder = Path(base_folder) / model.split('/')[-1]
base_folder = 'models' if not is_lora else 'loras'
output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main':
output_folder += f'_{branch}'
# Creating the folder and writing the metadata
output_folder = Path(base_folder) / output_folder
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads)
results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
pool.close()
pool.join()
download_files(links, output_folder, args.threads)
print()

View File

@ -43,14 +43,14 @@ class Handler(BaseHTTPRequestHandler):
generator = generate_reply(
question = prompt,
max_new_tokens = body.get('max_length', 200),
max_new_tokens = int(body.get('max_length', 200)),
do_sample=True,
temperature=body.get('temperature', 0.5),
top_p=body.get('top_p', 1),
typical_p=body.get('typical', 1),
repetition_penalty=body.get('rep_pen', 1.1),
temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1,
top_k=body.get('top_k', 0),
top_k=int(body.get('top_k', 0)),
min_length=0,
no_repeat_ngram_size=0,
num_beams=1,
@ -62,6 +62,9 @@ class Handler(BaseHTTPRequestHandler):
answer = ''
for a in generator:
if isinstance(a, str):
answer = a
else:
answer = a[0]
response = json.dumps({

View File

@ -26,6 +26,7 @@ current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
@ -77,6 +78,7 @@ def input_modifier(string):
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
shared.processing_message = "*Is recording a voice message...*"
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string
def output_modifier(string):
@ -84,7 +86,7 @@ def output_modifier(string):
This function is applied to the model outputs.
"""
global model, current_params
global model, current_params, streaming_state
for i in params:
if params[i] != current_params[i]:
@ -116,6 +118,7 @@ def output_modifier(string):
string += f'\n\n{original_string}'
shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string
def bot_prefix_modifier(string):

View File

@ -4,22 +4,60 @@ from pathlib import Path
import accelerate
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
import llama_inference_offload
import opt
from modelutils import find_layers
from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
for name in exclude_layers:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
model.seqlen = 2048
print('Done.')
return model
def load_quantized(model_name):
if not shared.args.model_type:
# Try to determine model type from model name
if model_name.lower().startswith(('llama', 'alpaca')):
name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')):
elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt'
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj'
else:
print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument")
@ -27,15 +65,12 @@ def load_quantized(model_name):
else:
model_type = shared.args.model_type.lower()
if model_type == 'llama':
if not shared.args.pre_layer:
load_quant = llama.load_quant
else:
if model_type == 'llama' and shared.args.pre_layer:
load_quant = llama_inference_offload.load_quant
elif model_type == 'opt':
load_quant = opt.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Now we are going to try to locate the quantized model file.
@ -75,7 +110,8 @@ def load_quantized(model_name):
if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize)
threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly)
if shared.args.gpu_memory:

View File

@ -1,6 +1,7 @@
from pathlib import Path
import torch
from peft import PeftModel
import modules.shared as shared
from modules.models import load_model
@ -14,15 +15,13 @@ def reload_model():
def add_lora_to_model(lora_name):
from peft import PeftModel
# If a LoRA had been previously loaded, or if we want
# to unload a LoRA, reload the model
if shared.lora_name != "None" or lora_name == "None":
if shared.lora_name not in ['None', ''] or lora_name in ['None', '']:
reload_model()
shared.lora_name = lora_name
if lora_name != "None":
if lora_name not in ['None', '']:
print(f"Adding the LoRA {lora_name} to the model...")
params = {}
if not shared.args.cpu:
@ -32,7 +31,7 @@ def add_lora_to_model(lora_name):
elif shared.args.load_in_8bit:
params['device_map'] = {'': 0}
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
if not shared.args.load_in_8bit and not shared.args.cpu:
shared.model.half()
if not hasattr(shared.model, "hf_device_map"):

View File

@ -1,4 +1,5 @@
import gc
import traceback
from queue import Queue
from threading import Thread
@ -54,7 +55,7 @@ class Iteratorize:
self.stop_now = False
def _callback(val):
if self.stop_now:
if self.stop_now or shared.stop_everything:
raise ValueError
self.q.put(val)
@ -63,6 +64,10 @@ class Iteratorize:
ret = self.mfunc(callback=_callback, **self.kwargs)
except ValueError:
pass
except:
traceback.print_exc()
pass
clear_torch_cache()
self.q.put(self.sentinel)
if self.c_callback:

View File

@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
reply = fix_newlines(reply)
return reply, next_character_found
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
name1_original = name1

View File

@ -7,7 +7,7 @@ import modules.shared as shared
state = {}
available_extensions = []
setup_called = False
setup_called = set()
def load_extensions():
global state
@ -53,18 +53,17 @@ def create_extensions_block():
should_display_ui = False
# Running setup function
if not setup_called:
for extension, name in iterator():
if hasattr(extension, "setup"):
extension.setup()
if hasattr(extension, "ui"):
should_display_ui = True
setup_called = True
if extension not in setup_called and hasattr(extension, "setup"):
setup_called.add(extension)
extension.setup()
# Creating the extension ui elements
if should_display_ui:
with gr.Box(elem_id="extensions"):
gr.Markdown("Extensions")
with gr.Column(elem_id="extensions"):
for extension, name in iterator():
gr.Markdown(f"\n### {name}")
if hasattr(extension, "ui"):
extension.ui()

View File

@ -34,7 +34,7 @@ def convert_to_markdown(string):
string = string.replace('\\begin{blockquote}', '> ')
string = string.replace('\\end{blockquote}', '')
string = re.sub(r"(.)```", r"\1\n```", string)
# string = fix_newlines(string)
string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string):

View File

@ -41,14 +41,14 @@ def load_model(model_name):
print(f"Loading {model_name}...")
t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-')
shared.is_RWKV = 'rwkv-' in model_name.lower()
# Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps:
device = torch.device('mps')
model = model.to(device)
@ -76,11 +76,11 @@ def load_model(model_name):
num_bits=4, group_size=64,
group_dim=2, symmetric=False))
model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@ -89,8 +89,8 @@ def load_model(model_name):
elif shared.is_RWKV:
from modules.RWKV import RWKVModel, RWKVTokenizer
model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
return model, tokenizer
@ -142,7 +142,7 @@ def load_model(model_name):
if shared.args.disk:
params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'models/{shared.model_name}')
checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint)
@ -159,10 +159,10 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

View File

@ -37,28 +37,23 @@ settings = {
'chat_generation_attempts': 1,
'chat_generation_attempts_min': 1,
'chat_generation_attempts_max': 5,
'name1_pygmalion': 'You',
'name2_pygmalion': 'Kawaii',
'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
'stop_at_newline_pygmalion': False,
'default_extensions': [],
'chat_default_extensions': ["gallery"],
'presets': {
'default': 'NovelAI-Sphinx Moth',
'(alpaca-*|llama-*)': "LLaMA-Precise",
'pygmalion-*': 'Pygmalion',
'RWKV-*': 'Naive',
'.*pygmalion': 'Pygmalion',
'.*RWKV': 'Naive',
},
'prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>',
'alpaca-*': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n",
'default': 'QA',
'.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan',
'.*oasst': 'Open Assistant',
'.*alpaca': "Alpaca",
},
'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
'default': 'QA',
'.*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Alpaca",
}
}
@ -85,7 +80,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --
parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
@ -108,11 +103,14 @@ parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile t
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
args = parser.parse_args()
# Provisional, this will be deleted later

View File

@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|>
if re.match('(oasst|galactica)-*', shared.model_name.lower()):
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@ -77,10 +77,10 @@ def fix_galactica(s):
def formatted_outputs(reply, model_name):
if not (shared.args.chat or shared.args.cai_chat):
if model_name.lower().startswith('galactica'):
if 'galactica' in model_name.lower():
reply = fix_galactica(reply)
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
@ -99,9 +99,13 @@ def set_manual_seed(seed):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(seed)
shared.stop_everything = False
t0 = time.time()
original_question = question
@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
break
yield formatted_outputs(reply, shared.model_name)
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(max_new_tokens//8+1):

275
modules/training.py Normal file
View File

@ -0,0 +1,275 @@
import json
import sys
import threading
import time
import traceback
from pathlib import Path
import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
prepare_model_for_int8_training)
from modules import shared, ui
WANT_INTERRUPT = False
CURRENT_STEPS = 0
MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str):
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower)
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
with gr.Row():
# TODO: Implement multi-device support.
micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
with gr.Row():
epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
# TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.')
lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
# TODO: Better explain what this does, in terms of real world effect especially.
lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.')
cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
with gr.Tab(label="Formatted Dataset"):
with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
with gr.Row():
start_button = gr.Button("Start LoRA Training")
stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready")
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS
CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM
MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS
CURRENT_STEPS += 1
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
# Or swap it to a strict whitelist of [a-zA-Z_0-9]
path = path.replace('\\', '/').replace('..', '_')
if base_path is None:
return path
return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
WANT_INTERRUPT = False
CURRENT_STEPS = 0
MAX_STEPS = 0
# == Input validation / processing ==
yield "Prepping..."
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
actual_lr = float(learning_rate)
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield f"Cannot input zeroes."
return
gradient_accumulation_steps = batch_size // micro_batch_size
CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps
shared.tokenizer.pad_token = 0
shared.tokenizer.padding_side = "left"
def tokenize(prompt):
result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
return {
"input_ids": result["input_ids"][:-1],
"attention_mask": result["attention_mask"][:-1],
}
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens
data = Dataset.from_list([tokenize(x) for x in text_chunks])
train_data = data.shuffle()
eval_data = None
del text_chunks
else:
if dataset in ['None', '']:
yield "**Missing dataset choice input, cannot continue.**"
return
if format in ['None', '']:
yield "**Missing format choice input, cannot continue.**"
return
with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile:
format_data: dict[str, str] = json.load(formatFile)
def generate_prompt(data_point: dict[str, str]):
for options, data in format_data.items():
if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0):
for key, val in data_point.items():
data = data.replace(f'%{key}%', val)
return data
raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
def generate_and_tokenize_prompt(data_point):
prompt = generate_prompt(data_point)
return tokenize(prompt)
print("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
if eval_dataset == 'None':
eval_data = None
else:
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
prepare_model_for_int8_training(shared.model)
print("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
# TODO: Should target_modules be configurable?
target_modules=[ "q_proj", "v_proj" ],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
try:
lora_model = get_peft_model(shared.model, config)
except:
yield traceback.format_exc()
return
trainer = transformers.Trainer(
model=lora_model,
train_dataset=train_data,
eval_dataset=eval_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
# TODO: Should more of these be configurable? Probably.
warmup_steps=100,
num_train_epochs=epochs,
learning_rate=actual_lr,
fp16=True,
logging_steps=20,
evaluation_strategy="steps" if eval_data is not None else "no",
save_strategy="steps",
eval_steps=200 if eval_data is not None else None,
save_steps=200,
output_dir=lora_name,
save_total_limit=3,
load_best_model_at_end=True if eval_data is not None else False,
# TODO: Enable multi-device support
ddp_find_unused_parameters=None
),
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()])
)
lora_model.config.use_cache = False
old_state_dict = lora_model.state_dict
lora_model.state_dict = (
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(lora_model, type(lora_model))
if torch.__version__ >= "2" and sys.platform != "win32":
lora_model = torch.compile(lora_model)
# == Main run and monitor loop ==
# TODO: save/load checkpoints to resume from?
print("Starting training...")
yield "Starting..."
def threadedRun():
trainer.train()
thread = threading.Thread(target=threadedRun)
thread.start()
lastStep = 0
startTime = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
elif CURRENT_STEPS != lastStep:
lastStep = CURRENT_STEPS
timeElapsed = time.perf_counter() - startTime
if timeElapsed <= 0:
timerInfo = ""
totalTimeEstimate = 999
else:
its = CURRENT_STEPS / timeElapsed
if its > 1:
timerInfo = f"`{its:.2f}` it/s"
else:
timerInfo = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
print("Training complete, saving...")
lora_model.save_pretrained(lora_name)
if WANT_INTERRUPT:
print("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`"
else:
print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step):
for i in range(0, len(arr), step):
yield arr[i:i + step]

6
prompts/Alpaca.txt Normal file
View File

@ -0,0 +1,6 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:

6
prompts/GPT-4chan.txt Normal file
View File

@ -0,0 +1,6 @@
-----
--- 865467536
Hello, AI frens!
How are you doing on this fine day?
--- 865467537

View File

@ -0,0 +1 @@
<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>

4
prompts/QA.txt Normal file
View File

@ -0,0 +1,4 @@
Common sense questions and answers
Question:
Factual answer:

View File

@ -1,13 +1,14 @@
accelerate==0.17.1
bitsandbytes==0.37.1
accelerate==0.18.0
bitsandbytes==0.37.2
flexgen==0.1.7
gradio==3.18.0
gradio==3.23.0
markdown
numpy
peft==0.2.0
requests
rwkv==0.7.0
rwkv==0.7.1
safetensors==0.3.0
sentencepiece
tqdm
datasets
git+https://github.com/huggingface/transformers

151
server.py
View File

@ -4,18 +4,18 @@ import re
import sys
import time
import zipfile
from datetime import datetime
from pathlib import Path
import gradio as gr
import modules.chat as chat
import modules.extensions as extensions_module
import modules.shared as shared
import modules.ui as ui
from modules import chat, shared, training, ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import clear_torch_cache, generate_reply
from modules.text_generation import (clear_torch_cache, generate_reply,
stop_everything_event)
# Loading custom settings
settings_file = None
@ -31,13 +31,20 @@ if settings_file is not None:
def get_available_models():
if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_prompts():
prompts = []
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
prompts += ['None']
return prompts
def get_available_characters():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
@ -48,22 +55,25 @@ def get_available_softprompts():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
shared.model = shared.tokenizer = None
clear_torch_cache()
unload_model()
if selected_model != '':
shared.model, shared.tokenizer = load_model(shared.model_name)
return selected_model
def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
return selected_lora, default_text
return selected_lora
def load_preset_values(preset_menu, return_dict=False):
generate_params = {
@ -93,7 +103,7 @@ def load_preset_values(preset_menu, return_dict=False):
if return_dict:
return generate_params
else:
return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@ -118,9 +128,46 @@ def create_model_and_preset_menus():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text)
return f"Saved to prompts/{fname}"
def load_prompt(fname):
if fname in ['None', '']:
return ''
else:
with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
text = f.read()
if text[-1] == '\n':
text = text[:-1]
return text
def create_prompt_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
with gr.Column():
with gr.Column():
shared.gradio['save_prompt'] = gr.Button('Save prompt')
shared.gradio['status'] = gr.Markdown('Ready')
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
with gr.Row():
with gr.Column():
create_model_and_preset_menus()
with gr.Column():
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
with gr.Row():
with gr.Column():
with gr.Box():
@ -151,12 +198,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
with gr.Row():
shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
@ -171,9 +212,8 @@ def create_settings_menus(default_preset):
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
@ -238,12 +278,10 @@ if shared.args.lora:
# Default UI settings
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
if shared.lora_name != "None":
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
else:
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
def create_interface():
@ -255,13 +293,13 @@ def create_interface():
if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate')
@ -274,12 +312,10 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
create_model_and_preset_menus()
with gr.Tab("Character", elem_id="chat-settings"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@ -317,7 +353,7 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset)
@ -328,7 +364,7 @@ def create_interface():
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@ -364,24 +400,34 @@ def create_interface():
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column(scale=4):
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop')
with gr.Column():
with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop')
with gr.Column():
pass
with gr.Column(scale=1):
gr.HTML('<div style="padding-bottom: 13px"></div>')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_model_and_preset_menus()
create_prompt_menus()
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
@ -389,7 +435,7 @@ def create_interface():
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else:
@ -405,7 +451,7 @@ def create_interface():
with gr.Column():
shared.gradio['Stop'] = gr.Button('Stop')
create_model_and_preset_menus()
create_prompt_menus()
with gr.Column():
with gr.Tab('Raw'):
@ -414,6 +460,7 @@ def create_interface():
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
@ -422,9 +469,12 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface()
with gr.Tab("Interface mode", elem_id="interface-mode"):
modes = ["default", "notebook", "chat", "cai_chat"]
current_mode = "default"
@ -443,17 +493,26 @@ def create_interface():
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}')
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
# Authentication
auth = None
if shared.args.gradio_auth_path is not None:
gradio_auth_creds = []
with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
# Launch the interface
shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface()

View File

@ -12,27 +12,23 @@
"chat_generation_attempts": 1,
"chat_generation_attempts_min": 1,
"chat_generation_attempts_max": 5,
"name1_pygmalion": "You",
"name2_pygmalion": "Kawaii",
"context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
"stop_at_newline_pygmalion": false,
"default_extensions": [],
"chat_default_extensions": [
"gallery"
],
"presets": {
"default": "NovelAI-Sphinx Moth",
"pygmalion-*": "Pygmalion",
"RWKV-*": "Naive"
".*pygmalion": "Pygmalion",
".*RWKV": "Naive"
},
"prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
"default": "QA",
".*(gpt4chan|gpt-4chan|4chan)": "GPT-4chan",
".*oasst": "Open Assistant",
".*alpaca": "Alpaca"
},
"lora_prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
"default": "QA",
".*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Alpaca"
}
}

View File

@ -0,0 +1,4 @@
{
"instruction,output": "User: %instruction%\nAssistant: %output%",
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}

View File

@ -0,0 +1,4 @@
{
"instruction,output": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Response:\n%output%",
"instruction,input,output": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Input:\n%input%\n\n### Response:\n%output%"
}