diff --git a/.gitignore b/.gitignore
index a9c47a5a..bfb6d027 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ torch-dumps
*/*/pycache*
venv/
.venv/
+.vscode
repositories
settings.json
diff --git a/README.md b/README.md
index 169c894b..e0784e12 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Text generation web UI
-A gradio web UI for running Large Language Models like GPT-J 6B, OPT, GALACTICA, LLaMA, and Pygmalion.
+A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
@@ -28,6 +28,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [LLaMA model, including 4-bit GPTQ support](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
+* [llama.cpp support](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models). **\*NEW!\***
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
* Supports softprompts.
@@ -175,24 +176,31 @@ Optionally, you can use the following command-line flags:
| Flag | Description |
|------------------|-------------|
| `-h`, `--help` | show this help message and exit |
-| `--model MODEL` | Name of the model to load by default. |
-| `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
+| `--model MODEL` | Name of the model to load by default. |
+| `--lora LORA` | Name of the LoRA to apply to the model by default. |
+| `--model-dir MODEL_DIR` | Path to directory with all the models |
+| `--lora-dir LORA_DIR` | Path to directory with all the loras |
+| `--no-stream` | Don't stream the text output in real time. |
+| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
+| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
+| `--verbose` | Print the prompts to the terminal. |
| `--cpu` | Use the CPU to generate text.|
+| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
+| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
+| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
+| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
+| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--load-in-8bit` | Load the model with 8-bit precision.|
+| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
+| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
+| `--threads` | Number of threads to use in llama.cpp. |
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to preload. |
-| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
-| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
-| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
-| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
-| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
-| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
-| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
@@ -202,12 +210,6 @@ Optionally, you can use the following command-line flags:
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
-| `--no-stream` | Don't stream the text output in real time. |
-| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
-| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
-| `--model-dir MODEL_DIR` | Path to directory with all the models |
-| `--lora-dir LORA_DIR` | Path to directory with all the loras |
-| `--verbose` | Print the prompts to the terminal. |
| `--listen` | Make the web UI reachable from your local network. |
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
diff --git a/download-model.py b/download-model.py
index 7e5f61b2..0f40ab50 100644
--- a/download-model.py
+++ b/download-model.py
@@ -9,6 +9,7 @@ python download-model.py facebook/opt-1.3b
import argparse
import base64
import datetime
+import hashlib
import json
import re
import sys
@@ -24,11 +25,28 @@ parser.add_argument('--branch', type=str, default='main', help='Name of the Git
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
+parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def get_file(url, output_folder):
- r = requests.get(url, stream=True)
- with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
+ filename = Path(url.rsplit('/', 1)[1])
+ output_path = output_folder / filename
+ if output_path.exists() and not args.clean:
+ # Check if the file has already been downloaded completely
+ r = requests.get(url, stream=True)
+ total_size = int(r.headers.get('content-length', 0))
+ if output_path.stat().st_size >= total_size:
+ return
+ # Otherwise, resume the download from where it left off
+ headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+ mode = 'ab'
+ else:
+ headers = {}
+ mode = 'wb'
+
+ r = requests.get(url, stream=True, headers=headers)
+ with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
@@ -97,6 +115,7 @@ def get_download_links_from_huggingface(model, branch):
classifications = []
has_pytorch = False
has_pt = False
+ has_ggml = False
has_safetensors = False
is_lora = False
while True:
@@ -114,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname)
+ is_ggml = re.match("ggml.*\.bin", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
@@ -135,6 +155,9 @@ def get_download_links_from_huggingface(model, branch):
elif is_pt:
has_pt = True
classifications.append('pt')
+ elif is_ggml:
+ has_ggml = True
+ classifications.append('ggml')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
@@ -149,7 +172,7 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8):
- thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads)
+ thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__':
model = args.MODEL
@@ -179,22 +202,48 @@ if __name__ == '__main__':
output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main':
output_folder += f'_{branch}'
-
- # Creating the folder and writing the metadata
output_folder = Path(base_folder) / output_folder
- if not output_folder.exists():
- output_folder.mkdir()
- with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
- f.write(f'url: https://huggingface.co/{model}\n')
- f.write(f'branch: {branch}\n')
- f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
- sha256_str = ''
- for i in range(len(sha256)):
- sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
- if sha256_str != '':
- f.write(f'sha256sum:\n{sha256_str}')
- # Downloading the files
- print(f"Downloading the model to {output_folder}")
- download_files(links, output_folder, args.threads)
- print()
+ if args.check:
+ # Validate the checksums
+ validated = True
+ for i in range(len(sha256)):
+ fpath = (output_folder / sha256[i][0])
+
+ if not fpath.exists():
+ print(f"The following file is missing: {fpath}")
+ validated = False
+ continue
+
+ with open(output_folder / sha256[i][0], "rb") as f:
+ bytes = f.read()
+ file_hash = hashlib.sha256(bytes).hexdigest()
+ if file_hash != sha256[i][1]:
+ print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
+ validated = False
+ else:
+ print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
+
+ if validated:
+ print('[+] Validated checksums of all model files!')
+ else:
+ print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+
+ else:
+
+ # Creating the folder and writing the metadata
+ if not output_folder.exists():
+ output_folder.mkdir()
+ with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+ f.write(f'url: https://huggingface.co/{model}\n')
+ f.write(f'branch: {branch}\n')
+ f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+ sha256_str = ''
+ for i in range(len(sha256)):
+ sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
+ if sha256_str != '':
+ f.write(f'sha256sum:\n{sha256_str}')
+
+ # Downloading the files
+ print(f"Downloading the model to {output_folder}")
+ download_files(links, output_folder, args.threads)
\ No newline at end of file
diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py
index fbf23bc9..c17d69ee 100644
--- a/extensions/gallery/script.py
+++ b/extensions/gallery/script.py
@@ -2,19 +2,29 @@ from pathlib import Path
import gradio as gr
+from modules.chat import load_character
from modules.html_generator import get_image_cache
+from modules.shared import gradio, settings
-def generate_html():
+def generate_css():
css = """
- .character-gallery {
+ .character-gallery > .gallery {
margin: 1rem 0;
- display: grid;
+ display: grid !important;
grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
grid-column-gap: 0.4rem;
grid-row-gap: 1.2rem;
}
+ .character-gallery > .label {
+ display: none !important;
+ }
+
+ .character-gallery button.gallery-item {
+ display: contents;
+ }
+
.character-container {
cursor: pointer;
text-align: center;
@@ -45,14 +55,16 @@ def generate_html():
overflow-wrap: anywhere;
}
"""
+ return css
- container_html = f'
'
+def generate_html():
+ cards = []
# Iterate through files in image folder
for file in sorted(Path("characters").glob("*")):
if file.name.endswith(".json"):
character = file.name.replace(".json", "")
- container_html += f'
'
+ container_html = f'
'
image_html = "
"
for i in [
@@ -71,12 +83,24 @@ def generate_html():
container_html += f'{image_html}
{character}'
container_html += "
"
+ cards.append([container_html, character])
+
+ return cards
+
+
+def select_character(evt: gr.SelectData):
+ return (evt.value[1])
- container_html += "
"
- return container_html
def ui():
with gr.Accordion("Character gallery", open=False):
update = gr.Button("Refresh")
- gallery = gr.HTML(value=generate_html())
+ gr.HTML(value="")
+ gallery = gr.Dataset(components=[gr.HTML(visible=False)],
+ label="",
+ samples=generate_html(),
+ elem_classes=["character-gallery"],
+ samples_per_page=50
+ )
update.click(generate_html, [], gallery)
+ gallery.select(select_character, None, gradio['character_menu'])
\ No newline at end of file
diff --git a/loras/place-your-loras-here.txt b/loras/place-your-loras-here.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/RWKV.py b/modules/RWKV.py
index 8c7ea2b9..10c4c366 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -34,7 +34,7 @@ class RWKVModel:
result.pipeline = pipeline
return result
- def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
args = PIPELINE_ARGS(
temperature = temperature,
top_p = top_p,
diff --git a/modules/callbacks.py b/modules/callbacks.py
index aa92f9cb..945b8c37 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -12,7 +12,7 @@ import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
- def __init__(self, sentinel_token_ids: list[torch.LongTensor], starting_idx: int):
+ def __init__(self, sentinel_token_ids: list, starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
diff --git a/modules/chat.py b/modules/chat.py
index cc3c45c7..db79e7db 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -22,7 +22,7 @@ def generate_chat_output(history, name1, name2, character):
else:
return history
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
@@ -51,7 +51,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
rows.pop(1)
prompt = ''.join(rows)
- return prompt
+
+ if also_return_rows:
+ return prompt, rows
+ else:
+ return prompt
def extract_message_from_reply(reply, name1, name2, check):
next_character_found = False
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
new file mode 100644
index 00000000..4f491329
--- /dev/null
+++ b/modules/llamacpp_model.py
@@ -0,0 +1,82 @@
+import multiprocessing
+
+import llamacpp
+
+from modules import shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppTokenizer:
+ """A thin wrapper over the llamacpp tokenizer"""
+ def __init__(self, model: llamacpp.LlamaInference):
+ self._tokenizer = model.get_tokenizer()
+ self.eos_token_id = 2
+ self.bos_token_id = 0
+
+ @classmethod
+ def from_model(cls, model: llamacpp.LlamaInference):
+ return cls(model)
+
+ def encode(self, prompt: str):
+ return self._tokenizer.tokenize(prompt)
+
+ def decode(self, ids):
+ return self._tokenizer.detokenize(ids)
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ @classmethod
+ def from_pretrained(self, path):
+ params = llamacpp.InferenceParams()
+ params.path_model = str(path)
+ params.n_threads = shared.args.threads or multiprocessing.cpu_count() // 2
+
+ _model = llamacpp.LlamaInference(params)
+
+ result = self()
+ result.model = _model
+ result.params = params
+
+ tokenizer = LlamaCppTokenizer.from_model(_model)
+ return result, tokenizer
+
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+ params = self.params
+ params.n_predict = token_count
+ params.top_p = top_p
+ params.top_k = top_k
+ params.temp = temperature
+ params.repeat_penalty = repetition_penalty
+ #params.repeat_last_n = repeat_last_n
+
+ #self.model.params = params
+ self.model.add_bos()
+ self.model.update_input(context)
+
+ output = ""
+ is_end_of_text = False
+ ctr = 0
+ while ctr < token_count and not is_end_of_text:
+ if self.model.has_unconsumed_input():
+ self.model.ingest_all_pending_input()
+ else:
+ self.model.eval()
+ token = self.model.sample()
+ text = self.model.token_to_str(token)
+ output += text
+ is_end_of_text = token == self.model.token_eos()
+ if callback:
+ callback(text)
+ ctr += 1
+
+ return output
+
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
diff --git a/modules/models.py b/modules/models.py
index b19507db..edcb3507 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -42,9 +42,10 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = 'rwkv-' in model_name.lower()
+ shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
# Default settings
- if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
+ if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@@ -100,6 +101,16 @@ def load_model(model_name):
model = load_quantized(model_name)
+ # llamacpp model
+ elif shared.is_llamacpp:
+ from modules.llamacpp_model import LlamaCppModel
+
+ model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
+ print(f"llama.cpp weights detected: {model_file}\n")
+
+ model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+ return model, tokenizer
+
# Custom
else:
params = {"low_cpu_mem_usage": True}
diff --git a/modules/shared.py b/modules/shared.py
index 06535d1e..608ef315 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -27,6 +27,7 @@ settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 2000,
+ 'seed': -1,
'name1': 'You',
'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
@@ -68,51 +69,68 @@ def str2bool(v):
raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
-parser.add_argument('--model', type=str, help='Name of the model to load by default.')
-parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
+
+# Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
-parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
-parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
+parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
-parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --wbits instead.')
-parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
-parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
+# Accelerate/transformers
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs.')
+parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
+parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
+parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+
+# llama.cpp
+parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
+
+# GPTQ
parser.add_argument('--wbits', type=int, default=0, help='GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to preload.')
+parser.add_argument('--gptq-bits', type=int, default=0, help='DEPRECATED: use --wbits instead.')
+parser.add_argument('--gptq-model-type', type=str, help='DEPRECATED: use --model_type instead.')
+parser.add_argument('--gptq-pre-layer', type=int, default=0, help='DEPRECATED: use --pre_layer instead.')
-parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
-parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
-parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
-parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
-parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs.')
-parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
-parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+# FlexGen
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
+
+# DeepSpeed
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+
+# RWKV
parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
-parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
-parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
-parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
-parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
-parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
-parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+
+# Gradio
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
+
args = parser.parse_args()
+
# Provisional, this will be deleted later
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict:
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7b5fcd6a..6ae592db 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
- if shared.is_RWKV:
+ if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
return input_ids
@@ -116,10 +116,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
- if shared.is_RWKV:
+ if any((shared.is_RWKV, shared.is_llamacpp)):
try:
if shared.args.no_stream:
- reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
+ reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+ output = original_question+reply
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@@ -129,7 +130,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
- for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+ for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
+ output = original_question+reply
if not (shared.args.chat or shared.args.cai_chat):
reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name)
@@ -138,9 +140,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
traceback.print_exc()
finally:
t1 = time.time()
- output = encode(reply)[0]
- input_ids = encode(question)
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+ original_tokens = len(encode(original_question)[0])
+ new_tokens = len(encode(output)[0]) - original_tokens
+ print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
input_ids = encode(question, max_new_tokens)
@@ -272,5 +274,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
traceback.print_exc()
finally:
t1 = time.time()
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens, context {len(original_input_ids[0])})")
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output)-original_tokens
+ print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
diff --git a/requirements.txt b/requirements.txt
index 79da715d..ffa6b51a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
accelerate==0.18.0
bitsandbytes==0.37.2
flexgen==0.1.7
-gradio==3.23.0
+gradio==3.24.0
+llamacpp==0.1.11
markdown
numpy
peft==0.2.0
diff --git a/server.py b/server.py
index 27223f84..ebd9c81e 100644
--- a/server.py
+++ b/server.py
@@ -166,7 +166,7 @@ def create_settings_menus(default_preset):
with gr.Column():
create_model_and_preset_menus()
with gr.Column():
- shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
+ shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
with gr.Row():
with gr.Column():
@@ -217,10 +217,11 @@ def create_settings_menus(default_preset):
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
-def set_interface_arguments(interface_mode, extensions, cmd_active):
+def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
- cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+ bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+ #int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions
for k in modes[1:]:
@@ -228,9 +229,9 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
if interface_mode != "default":
exec(f"shared.args.{interface_mode} = True")
- for k in cmd_list:
+ for k in bool_list:
exec(f"shared.args.{k} = False")
- for k in cmd_active:
+ for k in bool_active:
exec(f"shared.args.{k} = True")
shared.need_restart = True
@@ -408,7 +409,7 @@ def create_interface():
with gr.Row():
with gr.Column(scale=4):
with gr.Tab('Raw'):
- shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=27)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
@@ -442,7 +443,7 @@ def create_interface():
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column():
- shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=21, label='Input')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Row():
@@ -455,7 +456,7 @@ def create_interface():
with gr.Column():
with gr.Tab('Raw'):
- shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
+ shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output')
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
@@ -483,16 +484,17 @@ def create_interface():
current_mode = mode
break
cmd_list = vars(shared.args)
- cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
- active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]]
+ bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+ bool_active = [k for k in bool_list if vars(shared.args)[k]]
+ #int_list = [k for k in cmd_list if type(k) is int]
gr.Markdown("*Experimental*")
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
- shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags")
+ shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
- shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
+ shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'
Reloading...
\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None:
diff --git a/settings-template.json b/settings-template.json
index da767cda..4ce0ca7a 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -2,6 +2,7 @@
"max_new_tokens": 200,
"max_new_tokens_min": 1,
"max_new_tokens_max": 2000,
+ "seed": -1,
"name1": "You",
"name2": "Assistant",
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
@@ -18,7 +19,7 @@
],
"presets": {
"default": "NovelAI-Sphinx Moth",
- ".*pygmalion": "Pygmalion",
+ ".*pygmalion": "NovelAI-Storywriter",
".*RWKV": "Naive"
},
"prompts": {