diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 82cd1701..ce603a4f 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/stale@v5 with: stale-issue-message: "" - close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, you can reopen it (if you are the author) or leave a comment below." + close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, please leave a comment below." days-before-issue-stale: 30 days-before-issue-close: 0 stale-issue-label: "stale" diff --git a/.gitignore b/.gitignore index 36852916..a9c47a5a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ repositories settings.json img_bot* img_me* +prompts/[0-9]* diff --git a/README.md b/README.md index 3bfbc72f..97f26ccb 100644 --- a/README.md +++ b/README.md @@ -36,10 +36,32 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. ## Installation -The recommended installation methods are the following: +### One-click installers -* Linux and MacOS: using conda natively. -* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)). +[oobabooga-windows.zip](https://github.com/oobabooga/text-generation-webui/releases/download/installers/oobabooga-windows.zip) + +Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder. + +* To download a model, double click on "download-model" +* To start the web UI, double click on "start-webui" + +Source codes: https://github.com/oobabooga/one-click-installers + +> **Note** +> +> Thanks to [@jllllll](https://github.com/jllllll) and [@ClayShoaf](https://github.com/ClayShoaf), the Windows 1-click installer now sets up 8-bit and 4-bit requirements out of the box. No additional installation steps are necessary. + +> **Note** +> +> There is no need to run the installer as admin. + +### Manual installation using Conda + +Recommended if you have some experience with the command-line. + +On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide). + +#### 0. Install Conda Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html @@ -84,26 +106,10 @@ pip install -r requirements.txt > > For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859 -### Alternative: one-click installers -[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip) +### Alternative: manual Windows installation -[oobabooga-linux.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-linux.zip) - -Just download the zip above, extract it, and double click on "install". The web UI and all its dependencies will be installed in the same folder. - -* To download a model, double click on "download-model" -* To start the web UI, double click on "start-webui" - -Source codes: https://github.com/oobabooga/one-click-installers - -> **Note** -> -> To get 8-bit and 4-bit models working in your 1-click Windows installation, you can use the [one-click-bandaid](https://github.com/ClayShoaf/oobabooga-one-click-bandaid). - -### Alternative: native Windows installation - -As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings). +As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Windows installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-installation-guide). ### Alternative: Docker @@ -177,7 +183,7 @@ Optionally, you can use the following command-line flags: | `--cpu` | Use the CPU to generate text.| | `--load-in-8bit` | Load the model with 8-bit precision.| | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. | -| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported. | +| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. | | `--groupsize GROUPSIZE` | GPTQ: Group size. | | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | @@ -198,12 +204,15 @@ Optionally, you can use the following command-line flags: | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--no-stream` | Don't stream the text output in real time. | | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| -| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | -| `--listen` | Make the web UI reachable from your local network.| -| `--listen-port LISTEN_PORT` | The listening port that the server will use. | -| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | -| `--auto-launch` | Open the web UI in the default browser upon launch. | -| `--verbose` | Print the prompts to the terminal. | +| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | +| `--model-dir MODEL_DIR` | Path to directory with all the models | +| `--lora-dir LORA_DIR` | Path to directory with all the loras | +| `--verbose` | Print the prompts to the terminal. | +| `--listen` | Make the web UI reachable from your local network. | +| `--listen-port LISTEN_PORT` | The listening port that the server will use. | +| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | +| `--auto-launch` | Open the web UI in the default browser upon launch. | +| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" | Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). diff --git a/css/chat.css b/css/chat.css index 8d9d88a6..c8a9d70a 100644 --- a/css/chat.css +++ b/css/chat.css @@ -23,3 +23,16 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* { .pending.svelte-1ed2p3z { opacity: 1; } + +#extensions { + padding: 0; + padding: 0; +} + +#gradio-chatbot { + height: 66.67vh; +} + +.wrap.svelte-6roggh.svelte-6roggh { + max-height: 92.5%; +} diff --git a/css/main.css b/css/main.css index 09f3b6a8..6aa3bc1a 100644 --- a/css/main.css +++ b/css/main.css @@ -37,20 +37,29 @@ text-decoration: none !important; } -svg { - display: unset !important; - vertical-align: middle !important; - margin: 5px; -} - ol li p, ul li p { display: inline-block; } -#main, #parameters, #chat-settings, #interface-mode, #lora { +#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab { border: 0; } .gradio-container-3-18-0 .prose * h1, h2, h3, h4 { color: white; } + +.gradio-container { + max-width: 100% !important; + padding-top: 0 !important; +} + +#extensions { + padding: 15px; + padding: 15px; +} + +span.math.inline { + font-size: 27px; + vertical-align: baseline !important; +} diff --git a/css/main.js b/css/main.js index 9db3fe8b..029ecb62 100644 --- a/css/main.js +++ b/css/main.js @@ -11,7 +11,7 @@ let extensions = document.getElementById('extensions'); main_parent.addEventListener('click', function(e) { // Check if the main element is visible if (main.offsetHeight > 0 && main.offsetWidth > 0) { - extensions.style.display = 'block'; + extensions.style.display = 'flex'; } else { extensions.style.display = 'none'; } diff --git a/download-model.py b/download-model.py index 25386e5f..7e5f61b2 100644 --- a/download-model.py +++ b/download-model.py @@ -8,38 +8,33 @@ python download-model.py facebook/opt-1.3b import argparse import base64 +import datetime import json -import multiprocessing import re import sys from pathlib import Path import requests import tqdm +from tqdm.contrib.concurrent import thread_map parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') +parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') args = parser.parse_args() -def get_file(args): - url = args[0] - output_folder = args[1] - idx = args[2] - tot = args[3] - - print(f"Downloading file {idx} of {tot}...") +def get_file(url, output_folder): r = requests.get(url, stream=True) - with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: + with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f: total_size = int(r.headers.get('content-length', 0)) block_size = 1024 - t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) - for data in r.iter_content(block_size): - t.update(len(data)) - f.write(data) - t.close() + with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) def sanitize_branch_name(branch_name): pattern = re.compile(r"^[a-zA-Z0-9._-]+$") @@ -98,8 +93,10 @@ def get_download_links_from_huggingface(model, branch): cursor = b"" links = [] + sha256 = [] classifications = [] has_pytorch = False + has_pt = False has_safetensors = False is_lora = False while True: @@ -115,12 +112,14 @@ def get_download_links_from_huggingface(model, branch): is_lora = True is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) - is_safetensors = re.match("model.*\.safetensors", fname) + is_safetensors = re.match(".*\.safetensors", fname) is_pt = re.match(".*\.pt", fname) is_tokenizer = re.match("tokenizer.*\.model", fname) is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)): + if 'lfs' in dict[i]: + sha256.append([fname, dict[i]['lfs']['oid']]) if is_text: links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") classifications.append('text') @@ -134,6 +133,7 @@ def get_download_links_from_huggingface(model, branch): has_pytorch = True classifications.append('pytorch') elif is_pt: + has_pt = True classifications.append('pt') cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' @@ -141,12 +141,15 @@ def get_download_links_from_huggingface(model, branch): cursor = cursor.replace(b'=', b'%3D') # If both pytorch and safetensors are available, download safetensors only - if has_pytorch and has_safetensors: + if (has_pytorch or has_pt) and has_safetensors: for i in range(len(classifications)-1, -1, -1): - if classifications[i] == 'pytorch': + if classifications[i] in ['pytorch', 'pt']: links.pop(i) - return links, is_lora + return links, sha256, is_lora + +def download_files(file_list, output_folder, num_threads=8): + thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads) if __name__ == '__main__': model = args.MODEL @@ -166,18 +169,32 @@ if __name__ == '__main__': print(f"Error: {err_branch}") sys.exit() - links, is_lora = get_download_links_from_huggingface(model, branch) - base_folder = 'models' if not is_lora else 'loras' - if branch != 'main': - output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}') + links, sha256, is_lora = get_download_links_from_huggingface(model, branch) + + if args.output is not None: + base_folder = args.output else: - output_folder = Path(base_folder) / model.split('/')[-1] + base_folder = 'models' if not is_lora else 'loras' + + output_folder = f"{'_'.join(model.split('/')[-2:])}" + if branch != 'main': + output_folder += f'_{branch}' + + # Creating the folder and writing the metadata + output_folder = Path(base_folder) / output_folder if not output_folder.exists(): output_folder.mkdir() + with open(output_folder / 'huggingface-metadata.txt', 'w') as f: + f.write(f'url: https://huggingface.co/{model}\n') + f.write(f'branch: {branch}\n') + f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n') + sha256_str = '' + for i in range(len(sha256)): + sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n' + if sha256_str != '': + f.write(f'sha256sum:\n{sha256_str}') # Downloading the files print(f"Downloading the model to {output_folder}") - pool = multiprocessing.Pool(processes=args.threads) - results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))]) - pool.close() - pool.join() + download_files(links, output_folder, args.threads) + print() diff --git a/extensions/api/script.py b/extensions/api/script.py index bd7c1900..dd48f58f 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -43,14 +43,14 @@ class Handler(BaseHTTPRequestHandler): generator = generate_reply( question = prompt, - max_new_tokens = body.get('max_length', 200), + max_new_tokens = int(body.get('max_length', 200)), do_sample=True, - temperature=body.get('temperature', 0.5), - top_p=body.get('top_p', 1), - typical_p=body.get('typical', 1), - repetition_penalty=body.get('rep_pen', 1.1), + temperature=float(body.get('temperature', 0.5)), + top_p=float(body.get('top_p', 1)), + typical_p=float(body.get('typical', 1)), + repetition_penalty=float(body.get('rep_pen', 1.1)), encoder_repetition_penalty=1, - top_k=body.get('top_k', 0), + top_k=int(body.get('top_k', 0)), min_length=0, no_repeat_ngram_size=0, num_beams=1, @@ -62,7 +62,10 @@ class Handler(BaseHTTPRequestHandler): answer = '' for a in generator: - answer = a[0] + if isinstance(a, str): + answer = a + else: + answer = a[0] response = json.dumps({ 'results': [{ diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index a81a5da1..1352993a 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -26,6 +26,7 @@ current_params = params.copy() voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] +streaming_state = shared.args.no_stream # remember if chat streaming was enabled # Used for making text xml compatible, needed for voice pitch and speed control table = str.maketrans({ @@ -77,6 +78,7 @@ def input_modifier(string): shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')] shared.processing_message = "*Is recording a voice message...*" + shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated return string def output_modifier(string): @@ -84,7 +86,7 @@ def output_modifier(string): This function is applied to the model outputs. """ - global model, current_params + global model, current_params, streaming_state for i in params: if params[i] != current_params[i]: @@ -116,6 +118,7 @@ def output_modifier(string): string += f'\n\n{original_string}' shared.processing_message = "*Is typing...*" + shared.args.no_stream = streaming_state # restore the streaming option to the previous value return string def bot_prefix_modifier(string): diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index afb5695f..e7877de7 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -4,22 +4,60 @@ from pathlib import Path import accelerate import torch +import transformers +from transformers import AutoConfig, AutoModelForCausalLM import modules.shared as shared sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa"))) -import llama import llama_inference_offload -import opt +from modelutils import find_layers +from quant import make_quant +def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): + config = AutoConfig.from_pretrained(model) + def noop(*args, **kwargs): + pass + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config) + torch.set_default_dtype(torch.float) + model = model.eval() + layers = find_layers(model) + for name in exclude_layers: + if name in layers: + del layers[name] + make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + model.seqlen = 2048 + print('Done.') + + return model + def load_quantized(model_name): if not shared.args.model_type: # Try to determine model type from model name - if model_name.lower().startswith(('llama', 'alpaca')): + name = model_name.lower() + if any((k in name for k in ['llama', 'alpaca'])): model_type = 'llama' - elif model_name.lower().startswith(('opt', 'galactica')): + elif any((k in name for k in ['opt-', 'galactica'])): model_type = 'opt' + elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])): + model_type = 'gptj' else: print("Can't determine model type from model name. Please specify it manually using --model_type " "argument") @@ -27,15 +65,12 @@ def load_quantized(model_name): else: model_type = shared.args.model_type.lower() - if model_type == 'llama': - if not shared.args.pre_layer: - load_quant = llama.load_quant - else: - load_quant = llama_inference_offload.load_quant - elif model_type == 'opt': - load_quant = opt.load_quant + if model_type == 'llama' and shared.args.pre_layer: + load_quant = llama_inference_offload.load_quant + elif model_type in ('llama', 'opt', 'gptj'): + load_quant = _load_quant else: - print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported") + print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") exit() # Now we are going to try to locate the quantized model file. @@ -75,7 +110,8 @@ def load_quantized(model_name): if shared.args.pre_layer: model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) else: - model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize) + threshold = False if model_type == 'gptj' else 128 + model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold) # accelerate offload (doesn't work properly) if shared.args.gpu_memory: diff --git a/modules/LoRA.py b/modules/LoRA.py index 283fcf4c..8c30e609 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -1,6 +1,7 @@ from pathlib import Path import torch +from peft import PeftModel import modules.shared as shared from modules.models import load_model @@ -14,15 +15,13 @@ def reload_model(): def add_lora_to_model(lora_name): - from peft import PeftModel - # If a LoRA had been previously loaded, or if we want # to unload a LoRA, reload the model - if shared.lora_name != "None" or lora_name == "None": + if shared.lora_name not in ['None', ''] or lora_name in ['None', '']: reload_model() shared.lora_name = lora_name - if lora_name != "None": + if lora_name not in ['None', '']: print(f"Adding the LoRA {lora_name} to the model...") params = {} if not shared.args.cpu: @@ -32,7 +31,7 @@ def add_lora_to_model(lora_name): elif shared.args.load_in_8bit: params['device_map'] = {'': 0} - shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) + shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params) if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): diff --git a/modules/callbacks.py b/modules/callbacks.py index 8d30d615..aa92f9cb 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -1,4 +1,5 @@ import gc +import traceback from queue import Queue from threading import Thread @@ -54,7 +55,7 @@ class Iteratorize: self.stop_now = False def _callback(val): - if self.stop_now: + if self.stop_now or shared.stop_everything: raise ValueError self.q.put(val) @@ -63,6 +64,10 @@ class Iteratorize: ret = self.mfunc(callback=_callback, **self.kwargs) except ValueError: pass + except: + traceback.print_exc() + pass + clear_torch_cache() self.q.put(self.sentinel) if self.c_callback: diff --git a/modules/chat.py b/modules/chat.py index 1a43cf3d..cc3c45c7 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check): reply = fix_newlines(reply) return reply, next_character_found -def stop_everything_event(): - shared.stop_everything = True - def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): - shared.stop_everything = False just_started = True eos_token = '\n' if check else None name1_original = name1 diff --git a/modules/extensions.py b/modules/extensions.py index c55dc978..fe6a3945 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,7 @@ import modules.shared as shared state = {} available_extensions = [] -setup_called = False +setup_called = set() def load_extensions(): global state @@ -53,18 +53,17 @@ def create_extensions_block(): should_display_ui = False # Running setup function - if not setup_called: - for extension, name in iterator(): - if hasattr(extension, "setup"): - extension.setup() - if hasattr(extension, "ui"): - should_display_ui = True - setup_called = True + for extension, name in iterator(): + if hasattr(extension, "ui"): + should_display_ui = True + if extension not in setup_called and hasattr(extension, "setup"): + setup_called.add(extension) + extension.setup() # Creating the extension ui elements if should_display_ui: - with gr.Box(elem_id="extensions"): - gr.Markdown("Extensions") + with gr.Column(elem_id="extensions"): for extension, name in iterator(): + gr.Markdown(f"\n### {name}") if hasattr(extension, "ui"): extension.ui() diff --git a/modules/html_generator.py b/modules/html_generator.py index ff18c913..48d2e02e 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -34,7 +34,7 @@ def convert_to_markdown(string): string = string.replace('\\begin{blockquote}', '> ') string = string.replace('\\end{blockquote}', '') string = re.sub(r"(.)```", r"\1\n```", string) -# string = fix_newlines(string) + string = fix_newlines(string) return markdown.markdown(string, extensions=['fenced_code']) def generate_basic_html(string): diff --git a/modules/models.py b/modules/models.py index c9f03588..b19507db 100644 --- a/modules/models.py +++ b/modules/models.py @@ -41,14 +41,14 @@ def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() - shared.is_RWKV = model_name.lower().startswith('rwkv-') + shared.is_RWKV = 'rwkv-' in model_name.lower() # Default settings if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) if torch.has_mps: device = torch.device('mps') model = model.to(device) @@ -76,11 +76,11 @@ def load_model(model_name): num_bits=4, group_size=64, group_dim=2, symmetric=False)) - model = OptLM(f"facebook/{shared.model_name}", env, "models", policy) + model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy) # DeepSpeed ZeRO-3 elif shared.args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") @@ -89,8 +89,8 @@ def load_model(model_name): elif shared.is_RWKV: from modules.RWKV import RWKVModel, RWKVTokenizer - model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") - tokenizer = RWKVTokenizer.from_pretrained(Path('models')) + model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") + tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir)) return model, tokenizer @@ -142,7 +142,7 @@ def load_model(model_name): if shared.args.disk: params["offload_folder"] = shared.args.disk_cache_dir - checkpoint = Path(f'models/{shared.model_name}') + checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}') if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': config = AutoConfig.from_pretrained(checkpoint) @@ -159,10 +159,10 @@ def load_model(model_name): model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) # Loading the tokenizer - if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): - tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) + if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): + tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) else: - tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) + tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") diff --git a/modules/shared.py b/modules/shared.py index 6d34fd69..a2e24219 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -37,28 +37,23 @@ settings = { 'chat_generation_attempts': 1, 'chat_generation_attempts_min': 1, 'chat_generation_attempts_max': 5, - 'name1_pygmalion': 'You', - 'name2_pygmalion': 'Kawaii', - 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", - 'stop_at_newline_pygmalion': False, 'default_extensions': [], 'chat_default_extensions': ["gallery"], 'presets': { 'default': 'NovelAI-Sphinx Moth', '(alpaca-*|llama-*)': "LLaMA-Precise", - 'pygmalion-*': 'Pygmalion', - 'RWKV-*': 'Naive', + '.*pygmalion': 'Pygmalion', + '.*RWKV': 'Naive', }, 'prompts': { - 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', - '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', - '(rosey|chip|joi)_.*_instruct.*': 'User: \n', - 'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>', - 'alpaca-*': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n", + 'default': 'QA', + '.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan', + '.*oasst': 'Open Assistant', + '.*alpaca': "Alpaca", }, 'lora_prompts': { - 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', - '(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" + 'default': 'QA', + '.*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Alpaca", } } @@ -85,7 +80,7 @@ parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use -- parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.') parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.') parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') -parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently only LLaMA and OPT are supported.') +parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.') parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.') parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.') @@ -108,11 +103,14 @@ parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile t parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.') parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') +parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") +parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") +parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') -parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None) args = parser.parse_args() # Provisional, this will be deleted later diff --git a/modules/text_generation.py b/modules/text_generation.py index 9b2c233d..7b5fcd6a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): def decode(output_ids): # Open Assistant relies on special tokens like <|endoftext|> - if re.match('(oasst|galactica)-*', shared.model_name.lower()): + if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): return shared.tokenizer.decode(output_ids, skip_special_tokens=False) else: reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) @@ -77,10 +77,10 @@ def fix_galactica(s): def formatted_outputs(reply, model_name): if not (shared.args.chat or shared.args.cai_chat): - if model_name.lower().startswith('galactica'): + if 'galactica' in model_name.lower(): reply = fix_galactica(reply) return reply, reply, generate_basic_html(reply) - elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')): + elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])): reply = fix_gpt4chan(reply) return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) else: @@ -99,9 +99,13 @@ def set_manual_seed(seed): if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) +def stop_everything_event(): + shared.stop_everything = True + def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]): clear_torch_cache() set_manual_seed(seed) + shared.stop_everything = False t0 = time.time() original_question = question @@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi break yield formatted_outputs(reply, shared.model_name) - yield formatted_outputs(reply, shared.model_name) - # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: for i in range(max_new_tokens//8+1): diff --git a/modules/training.py b/modules/training.py new file mode 100644 index 00000000..62ba181c --- /dev/null +++ b/modules/training.py @@ -0,0 +1,275 @@ +import json +import sys +import threading +import time +import traceback +from pathlib import Path + +import gradio as gr +import torch +import transformers +from datasets import Dataset, load_dataset +from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict, + prepare_model_for_int8_training) + +from modules import shared, ui + +WANT_INTERRUPT = False +CURRENT_STEPS = 0 +MAX_STEPS = 0 +CURRENT_GRADIENT_ACCUM = 1 + +def get_dataset(path: str, ext: str): + return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob(f'*.{ext}'))), key=str.lower) + +def create_train_interface(): + with gr.Tab('Train LoRA', elem_id='lora-train-tab'): + lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file") + with gr.Row(): + # TODO: Implement multi-device support. + micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') + batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') + + with gr.Row(): + epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') + learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') + + # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale. + lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, high values like 128 or 256 are good for teaching content upgrades. Higher ranks also require higher VRAM.') + lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') + # TODO: Better explain what this does, in terms of real world effect especially. + lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers.') + cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.') + + with gr.Tab(label="Formatted Dataset"): + with gr.Row(): + dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.') + ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') + eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.') + ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') + format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') + ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button') + with gr.Tab(label="Raw Text File"): + with gr.Row(): + raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') + ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button') + overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.') + + with gr.Row(): + start_button = gr.Button("Start LoRA Training") + stop_button = gr.Button("Interrupt") + + output = gr.Markdown(value="Ready") + start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output]) + stop_button.click(do_interrupt, [], [], cancels=[], queue=False) + +def do_interrupt(): + global WANT_INTERRUPT + WANT_INTERRUPT = True + +class Callbacks(transformers.TrainerCallback): + def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): + global CURRENT_STEPS, MAX_STEPS + CURRENT_STEPS = state.global_step * CURRENT_GRADIENT_ACCUM + MAX_STEPS = state.max_steps * CURRENT_GRADIENT_ACCUM + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True + def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): + global CURRENT_STEPS + CURRENT_STEPS += 1 + if WANT_INTERRUPT: + control.should_epoch_stop = True + control.should_training_stop = True + +def clean_path(base_path: str, path: str): + """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" + # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. + # Or swap it to a strict whitelist of [a-zA-Z_0-9] + path = path.replace('\\', '/').replace('..', '_') + if base_path is None: + return path + return f'{Path(base_path).absolute()}/{path}' + +def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, + lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int): + global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM + WANT_INTERRUPT = False + CURRENT_STEPS = 0 + MAX_STEPS = 0 + + # == Input validation / processing == + yield "Prepping..." + lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}" + actual_lr = float(learning_rate) + + if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: + yield f"Cannot input zeroes." + return + + gradient_accumulation_steps = batch_size // micro_batch_size + CURRENT_GRADIENT_ACCUM = gradient_accumulation_steps + shared.tokenizer.pad_token = 0 + shared.tokenizer.padding_side = "left" + + def tokenize(prompt): + result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length") + return { + "input_ids": result["input_ids"][:-1], + "attention_mask": result["attention_mask"][:-1], + } + + # == Prep the dataset, format, etc == + if raw_text_file not in ['None', '']: + print("Loading raw text file dataset...") + with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file: + raw_text = file.read() + tokens = shared.tokenizer.encode(raw_text) + del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM + tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) + for i in range(1, len(tokens)): + tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] + text_chunks = [shared.tokenizer.decode(x) for x in tokens] + del tokens + data = Dataset.from_list([tokenize(x) for x in text_chunks]) + train_data = data.shuffle() + eval_data = None + del text_chunks + + else: + if dataset in ['None', '']: + yield "**Missing dataset choice input, cannot continue.**" + return + + if format in ['None', '']: + yield "**Missing format choice input, cannot continue.**" + return + + with open(clean_path('training/formats', f'{format}.json'), 'r') as formatFile: + format_data: dict[str, str] = json.load(formatFile) + + def generate_prompt(data_point: dict[str, str]): + for options, data in format_data.items(): + if set(options.split(',')) == set(x[0] for x in data_point.items() if len(x[1].strip()) > 0): + for key, val in data_point.items(): + data = data.replace(f'%{key}%', val) + return data + raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"') + + def generate_and_tokenize_prompt(data_point): + prompt = generate_prompt(data_point) + return tokenize(prompt) + + print("Loading JSON datasets...") + data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json')) + train_data = data['train'].shuffle().map(generate_and_tokenize_prompt) + + if eval_dataset == 'None': + eval_data = None + else: + eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json')) + eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt) + + # == Start prepping the model itself == + if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): + print("Getting model ready...") + prepare_model_for_int8_training(shared.model) + + print("Prepping for training...") + config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + # TODO: Should target_modules be configurable? + target_modules=[ "q_proj", "v_proj" ], + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM" + ) + + try: + lora_model = get_peft_model(shared.model, config) + except: + yield traceback.format_exc() + return + + trainer = transformers.Trainer( + model=lora_model, + train_dataset=train_data, + eval_dataset=eval_data, + args=transformers.TrainingArguments( + per_device_train_batch_size=micro_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + # TODO: Should more of these be configurable? Probably. + warmup_steps=100, + num_train_epochs=epochs, + learning_rate=actual_lr, + fp16=True, + logging_steps=20, + evaluation_strategy="steps" if eval_data is not None else "no", + save_strategy="steps", + eval_steps=200 if eval_data is not None else None, + save_steps=200, + output_dir=lora_name, + save_total_limit=3, + load_best_model_at_end=True if eval_data is not None else False, + # TODO: Enable multi-device support + ddp_find_unused_parameters=None + ), + data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), + callbacks=list([Callbacks()]) + ) + + lora_model.config.use_cache = False + old_state_dict = lora_model.state_dict + lora_model.state_dict = ( + lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) + ).__get__(lora_model, type(lora_model)) + + if torch.__version__ >= "2" and sys.platform != "win32": + lora_model = torch.compile(lora_model) + + # == Main run and monitor loop == + # TODO: save/load checkpoints to resume from? + print("Starting training...") + yield "Starting..." + + def threadedRun(): + trainer.train() + + thread = threading.Thread(target=threadedRun) + thread.start() + lastStep = 0 + startTime = time.perf_counter() + + while thread.is_alive(): + time.sleep(0.5) + if WANT_INTERRUPT: + yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*" + elif CURRENT_STEPS != lastStep: + lastStep = CURRENT_STEPS + timeElapsed = time.perf_counter() - startTime + if timeElapsed <= 0: + timerInfo = "" + totalTimeEstimate = 999 + else: + its = CURRENT_STEPS / timeElapsed + if its > 1: + timerInfo = f"`{its:.2f}` it/s" + else: + timerInfo = f"`{1.0/its:.2f}` s/it" + totalTimeEstimate = (1.0/its) * (MAX_STEPS) + yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds" + + print("Training complete, saving...") + lora_model.save_pretrained(lora_name) + + if WANT_INTERRUPT: + print("Training interrupted.") + yield f"Interrupted. Incomplete LoRA saved to `{lora_name}`" + else: + print("Training complete!") + yield f"Done! LoRA saved to `{lora_name}`" + +def split_chunks(arr, step): + for i in range(0, len(arr), step): + yield arr[i:i + step] diff --git a/prompts/Alpaca.txt b/prompts/Alpaca.txt new file mode 100644 index 00000000..8434a80c --- /dev/null +++ b/prompts/Alpaca.txt @@ -0,0 +1,6 @@ +Below is an instruction that describes a task. Write a response that appropriately completes the request. +### Instruction: +Write a poem about the transformers Python library. +Mention the word "large language models" in that poem. +### Response: + diff --git a/prompts/GPT-4chan.txt b/prompts/GPT-4chan.txt new file mode 100644 index 00000000..1bc8c7f4 --- /dev/null +++ b/prompts/GPT-4chan.txt @@ -0,0 +1,6 @@ +----- +--- 865467536 +Hello, AI frens! +How are you doing on this fine day? +--- 865467537 + diff --git a/prompts/Open Assistant.txt b/prompts/Open Assistant.txt new file mode 100644 index 00000000..cf1ae4a2 --- /dev/null +++ b/prompts/Open Assistant.txt @@ -0,0 +1 @@ +<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|> diff --git a/prompts/QA.txt b/prompts/QA.txt new file mode 100644 index 00000000..32b0e235 --- /dev/null +++ b/prompts/QA.txt @@ -0,0 +1,4 @@ +Common sense questions and answers + +Question: +Factual answer: diff --git a/requirements.txt b/requirements.txt index e5b3de69..79da715d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ -accelerate==0.17.1 -bitsandbytes==0.37.1 +accelerate==0.18.0 +bitsandbytes==0.37.2 flexgen==0.1.7 -gradio==3.18.0 +gradio==3.23.0 markdown numpy peft==0.2.0 requests -rwkv==0.7.0 +rwkv==0.7.1 safetensors==0.3.0 sentencepiece tqdm +datasets git+https://github.com/huggingface/transformers diff --git a/server.py b/server.py index f1b95a5b..27223f84 100644 --- a/server.py +++ b/server.py @@ -4,18 +4,18 @@ import re import sys import time import zipfile +from datetime import datetime from pathlib import Path import gradio as gr -import modules.chat as chat import modules.extensions as extensions_module -import modules.shared as shared -import modules.ui as ui +from modules import chat, shared, training, ui from modules.html_generator import generate_chat_html from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt -from modules.text_generation import clear_torch_cache, generate_reply +from modules.text_generation import (clear_torch_cache, generate_reply, + stop_everything_event) # Loading custom settings settings_file = None @@ -31,13 +31,20 @@ if settings_file is not None: def get_available_models(): if shared.args.flexgen: - return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) + return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) +def get_available_prompts(): + prompts = [] + prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) + prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower) + prompts += ['None'] + return prompts + def get_available_characters(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower) @@ -48,22 +55,25 @@ def get_available_softprompts(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) def get_available_loras(): - return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + +def unload_model(): + shared.model = shared.tokenizer = None + clear_torch_cache() def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model - shared.model = shared.tokenizer = None - clear_torch_cache() - shared.model, shared.tokenizer = load_model(shared.model_name) + + unload_model() + if selected_model != '': + shared.model, shared.tokenizer = load_model(shared.model_name) return selected_model def load_lora_wrapper(selected_lora): add_lora_to_model(selected_lora) - default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] - - return selected_lora, default_text + return selected_lora def load_preset_values(preset_menu, return_dict=False): generate_params = { @@ -93,7 +103,7 @@ def load_preset_values(preset_menu, return_dict=False): if return_dict: return generate_params else: - return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] + return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: @@ -118,9 +128,46 @@ def create_model_and_preset_menus(): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') +def save_prompt(text): + fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt" + with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: + f.write(text) + return f"Saved to prompts/{fname}" + +def load_prompt(fname): + if fname in ['None', '']: + return '' + else: + with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f: + text = f.read() + if text[-1] == '\n': + text = text[:-1] + return text + +def create_prompt_menus(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt') + ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button') + + with gr.Column(): + with gr.Column(): + shared.gradio['save_prompt'] = gr.Button('Save prompt') + shared.gradio['status'] = gr.Markdown('Ready') + + shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) + shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) + def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) + with gr.Row(): + with gr.Column(): + create_model_and_preset_menus() + with gr.Column(): + shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)') + with gr.Row(): with gr.Column(): with gr.Box(): @@ -151,12 +198,6 @@ def create_settings_menus(default_preset): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)') - - with gr.Row(): - shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') - ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') @@ -171,9 +212,8 @@ def create_settings_menus(default_preset): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) - shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) - shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) @@ -238,12 +278,10 @@ if shared.args.lora: # Default UI settings default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] if shared.lora_name != "None": - default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] + default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]) else: - default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] + default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]) title ='Text generation web UI' -description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' -suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' def create_interface(): @@ -255,13 +293,13 @@ def create_interface(): if shared.args.chat or shared.args.cai_chat: with gr.Tab("Text generation", elem_id="main"): if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character)) else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) + shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot") shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") with gr.Row(): shared.gradio['Impersonate'] = gr.Button('Impersonate') shared.gradio['Regenerate'] = gr.Button('Regenerate') @@ -274,12 +312,10 @@ def create_interface(): shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - create_model_and_preset_menus() - with gr.Tab("Character", elem_id="chat-settings"): - shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') - shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') + shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context') with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') @@ -317,7 +353,7 @@ def create_interface(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) with gr.Column(): shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') + shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') create_settings_menus(default_preset) @@ -328,7 +364,7 @@ def create_interface(): gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) @@ -364,24 +400,34 @@ def create_interface(): shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") - shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) + shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: with gr.Tab("Text generation", elem_id="main"): - with gr.Tab('Raw'): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25) - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() - with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['Generate'] = gr.Button('Generate') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + with gr.Column(scale=4): + with gr.Tab('Raw'): + shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25) + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() + + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['Stop'] = gr.Button('Stop') + with gr.Column(): + pass + + with gr.Column(scale=1): + gr.HTML('
') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + + create_prompt_menus() - create_model_and_preset_menus() with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) @@ -389,7 +435,7 @@ def create_interface(): output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") else: @@ -405,7 +451,7 @@ def create_interface(): with gr.Column(): shared.gradio['Stop'] = gr.Button('Stop') - create_model_and_preset_menus() + create_prompt_menus() with gr.Column(): with gr.Tab('Raw'): @@ -414,6 +460,7 @@ def create_interface(): shared.gradio['markdown'] = gr.Markdown() with gr.Tab('HTML'): shared.gradio['html'] = gr.HTML() + with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) @@ -422,9 +469,12 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + with gr.Tab("Training", elem_id="training-tab"): + training.create_train_interface() + with gr.Tab("Interface mode", elem_id="interface-mode"): modes = ["default", "notebook", "chat", "cai_chat"] current_mode = "default" @@ -443,17 +493,26 @@ def create_interface(): shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary") shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None) - shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500)}') + shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500); return []}') if shared.args.extensions is not None: extensions_module.create_extensions_block() + # Authentication + auth = None + if shared.args.gradio_auth_path is not None: + gradio_auth_creds = [] + with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] + auth = [tuple(cred.split(':')) for cred in gradio_auth_creds] + # Launch the interface shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) create_interface() diff --git a/settings-template.json b/settings-template.json index 79fd5023..da767cda 100644 --- a/settings-template.json +++ b/settings-template.json @@ -12,27 +12,23 @@ "chat_generation_attempts": 1, "chat_generation_attempts_min": 1, "chat_generation_attempts_max": 5, - "name1_pygmalion": "You", - "name2_pygmalion": "Kawaii", - "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n", - "stop_at_newline_pygmalion": false, "default_extensions": [], "chat_default_extensions": [ "gallery" ], "presets": { "default": "NovelAI-Sphinx Moth", - "pygmalion-*": "Pygmalion", - "RWKV-*": "Naive" + ".*pygmalion": "Pygmalion", + ".*RWKV": "Naive" }, "prompts": { - "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", - "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", - "(rosey|chip|joi)_.*_instruct.*": "User: \n", - "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>" + "default": "QA", + ".*(gpt4chan|gpt-4chan|4chan)": "GPT-4chan", + ".*oasst": "Open Assistant", + ".*alpaca": "Alpaca" }, "lora_prompts": { - "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", - "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" + "default": "QA", + ".*(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Alpaca" } } diff --git a/loras/place-your-loras-here.txt b/training/datasets/put-trainer-datasets-here.txt similarity index 100% rename from loras/place-your-loras-here.txt rename to training/datasets/put-trainer-datasets-here.txt diff --git a/training/formats/alpaca-chatbot-format.json b/training/formats/alpaca-chatbot-format.json new file mode 100644 index 00000000..4b38103f --- /dev/null +++ b/training/formats/alpaca-chatbot-format.json @@ -0,0 +1,4 @@ +{ + "instruction,output": "User: %instruction%\nAssistant: %output%", + "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%" +} diff --git a/training/formats/alpaca-format.json b/training/formats/alpaca-format.json new file mode 100644 index 00000000..dd6df956 --- /dev/null +++ b/training/formats/alpaca-format.json @@ -0,0 +1,4 @@ +{ + "instruction,output": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Response:\n%output%", + "instruction,input,output": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n%instruction%\n\n### Input:\n%input%\n\n### Response:\n%output%" +}