diff --git a/README.md b/README.md index ad3736a4..17c9add7 100644 --- a/README.md +++ b/README.md @@ -337,6 +337,7 @@ Optionally, you can use the following command-line flags: |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. | |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. | |`--no_flash_attn` | Force flash-attention to not be used. | +|`--cache_8bit` | Use 8-bit cache to save VRAM. | #### AutoGPTQ diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 3f3b3587..e2bcfd1b 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -6,6 +6,7 @@ import torch from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, + ExLlamaV2Cache_8bit, ExLlamaV2Config, ExLlamaV2Tokenizer ) @@ -57,7 +58,11 @@ class Exllamav2Model: model.load(split) tokenizer = ExLlamaV2Tokenizer(config) - cache = ExLlamaV2Cache(model) + if shared.args.cache_8bit: + cache = ExLlamaV2Cache_8bit(model) + else: + cache = ExLlamaV2Cache(model) + generator = ExLlamaV2BaseGenerator(model, cache, tokenizer) result = self() diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 5d4aa515..30e3fe48 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -4,7 +4,12 @@ from pathlib import Path from typing import Any, Dict, Optional, Union import torch -from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config +from exllamav2 import ( + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ExLlamaV2Config +) from torch.nn import CrossEntropyLoss from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast @@ -40,11 +45,18 @@ class Exllamav2HF(PreTrainedModel): self.generation_config = GenerationConfig() self.loras = None - self.ex_cache = ExLlamaV2Cache(self.ex_model) - self.past_seq = None + if shared.args.cache_8bit: + self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model) + else: + self.ex_cache = ExLlamaV2Cache(self.ex_model) + self.past_seq = None if shared.args.cfg_cache: - self.ex_cache_negative = ExLlamaV2Cache(self.ex_model) + if shared.args.cache_8bit: + self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model) + else: + self.ex_cache_negative = ExLlamaV2Cache(self.ex_model) + self.past_seq_negative = None def _validate_model_class(self): diff --git a/modules/loaders.py b/modules/loaders.py index 577ac9d5..cbd8211f 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -41,6 +41,8 @@ loaders_and_params = OrderedDict({ 'gpu_split', 'max_seq_len', 'cfg_cache', + 'no_flash_attn', + 'cache_8bit', 'alpha_value', 'compress_pos_emb', 'use_fast', @@ -56,6 +58,8 @@ loaders_and_params = OrderedDict({ 'ExLlamav2': [ 'gpu_split', 'max_seq_len', + 'no_flash_attn', + 'cache_8bit', 'alpha_value', 'compress_pos_emb', ], diff --git a/modules/shared.py b/modules/shared.py index e1da167f..8523930f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -118,6 +118,7 @@ parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.') parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.') parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') +parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.') # AutoGPTQ parser.add_argument('--triton', action='store_true', help='Use triton.') diff --git a/modules/ui.py b/modules/ui.py index 9d87bad6..c095c013 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -68,6 +68,8 @@ def list_model_elements(): 'no_use_cuda_fp16', 'disable_exllama', 'cfg_cache', + 'no_flash_attn', + 'cache_8bit', 'threads', 'threads_batch', 'n_batch', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index b9a8cbe7..833c1308 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -125,6 +125,8 @@ def create_ui(): shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code) shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.') shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.') + shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.') + shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).') shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).') shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')