Add cache_8bit option

oobabooga 2023-11-02 11:23:04 -07:00
parent 42f816312d
commit c0655475ae
7 changed files with 32 additions and 5 deletions

@@ -337,6 +337,7 @@ Optionally, you can use the following command-line flags:
 |`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
 |`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
 |`--no_flash_attn` | Force flash-attention to not be used. |
+|`--cache_8bit` | Use 8-bit cache to save VRAM. |

 #### AutoGPTQ

@@ -6,6 +6,7 @@ import torch
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
     ExLlamaV2Config,
     ExLlamaV2Tokenizer
 )
@@ -57,7 +58,11 @@ class Exllamav2Model:
         model.load(split)

         tokenizer = ExLlamaV2Tokenizer(config)
-        cache = ExLlamaV2Cache(model)
+        if shared.args.cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model)
+        else:
+            cache = ExLlamaV2Cache(model)
+
         generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

         result = self()
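
For context, a standalone sketch of where the new cache class plugs into the ExLlamaV2 loading path. Only the cache selection mirrors the patch; the model directory, the config/prepare calls, and the generator import path are written from the ExLlamaV2 API as an illustration and are not part of this commit.

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer
)
from exllamav2.generator import ExLlamaV2BaseGenerator

use_8bit_cache = True                 # what --cache_8bit toggles
model_dir = "models/your-exl2-model"  # hypothetical path

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()

model = ExLlamaV2(config)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)

# The 8-bit cache stores keys/values in 8 bits instead of FP16,
# cutting the KV-cache VRAM footprint roughly in half.
cache = ExLlamaV2Cache_8bit(model) if use_8bit_cache else ExLlamaV2Cache(model)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)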

@@ -4,7 +4,12 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Union

 import torch
-from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
+from exllamav2 import (
+    ExLlamaV2,
+    ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
+    ExLlamaV2Config
+)
 from torch.nn import CrossEntropyLoss
 from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -40,11 +45,18 @@ class Exllamav2HF(PreTrainedModel):
         self.generation_config = GenerationConfig()
         self.loras = None

-        self.ex_cache = ExLlamaV2Cache(self.ex_model)
+        if shared.args.cache_8bit:
+            self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
+        else:
+            self.ex_cache = ExLlamaV2Cache(self.ex_model)
+
         self.past_seq = None

         if shared.args.cfg_cache:
-            self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
+            if shared.args.cache_8bit:
+                self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
+            else:
+                self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
             self.past_seq_negative = None

     def _validate_model_class(self):

@@ -41,6 +41,8 @@ loaders_and_params = OrderedDict({
         'gpu_split',
         'max_seq_len',
         'cfg_cache',
+        'no_flash_attn',
+        'cache_8bit',
         'alpha_value',
         'compress_pos_emb',
         'use_fast',
@@ -56,6 +58,8 @@ loaders_and_params = OrderedDict({
     'ExLlamav2': [
         'gpu_split',
         'max_seq_len',
+        'no_flash_attn',
+        'cache_8bit',
         'alpha_value',
         'compress_pos_emb',
     ],

@@ -118,6 +118,7 @@ parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM
 parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
 parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
 parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
+parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')

 # AutoGPTQ
 parser.add_argument('--triton', action='store_true', help='Use triton.')
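
The new flag is a plain argparse boolean, so passing --cache_8bit when launching the web UI simply flips shared.args.cache_8bit from False to True and the loaders read it from there. A minimal illustration of that behavior in isolation (the parse_args inputs here are made up for the example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')

# Omitting the flag leaves it False; passing it sets it to True.
assert parser.parse_args([]).cache_8bit is False
assert parser.parse_args(['--cache_8bit']).cache_8bit is True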

@@ -68,6 +68,8 @@ def list_model_elements():
         'no_use_cuda_fp16',
         'disable_exllama',
         'cfg_cache',
+        'no_flash_attn',
+        'cache_8bit',
         'threads',
         'threads_batch',
         'n_batch',

@@ -125,6 +125,8 @@ def create_ui():
             shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
             shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
             shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
+            shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
+            shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
             shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
             shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
             shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
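
On the UI side the option is just a Gradio checkbox whose initial value mirrors the command-line flag. A reduced sketch outside the web UI's layout, assuming a recent Gradio install; the Blocks wrapper, the default value variable, and the launch call are illustrative and not part of the patch:

import gradio as gr

cache_8bit_default = False  # stand-in for shared.args.cache_8bit

with gr.Blocks() as demo:
    # Same parameters the patch uses: label, initial value, and an info tooltip.
    cache_8bit = gr.Checkbox(label="cache_8bit", value=cache_8bit_default, info='Use 8-bit cache to save VRAM.')

demo.launch()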