mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-12 05:17:40 +01:00
Add cache_8bit option
This commit is contained in:
parent
42f816312d
commit
c0655475ae
@ -337,6 +337,7 @@ Optionally, you can use the following command-line flags:
|
||||
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
|
||||
|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
|
||||
|`--no_flash_attn` | Force flash-attention to not be used. |
|
||||
|`--cache_8bit` | Use 8-bit cache to save VRAM. |
|
||||
|
||||
#### AutoGPTQ
|
||||
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2Tokenizer
|
||||
)
|
||||
@ -57,7 +58,11 @@ class Exllamav2Model:
|
||||
model.load(split)
|
||||
|
||||
tokenizer = ExLlamaV2Tokenizer(config)
|
||||
cache = ExLlamaV2Cache(model)
|
||||
if shared.args.cache_8bit:
|
||||
cache = ExLlamaV2Cache_8bit(model)
|
||||
else:
|
||||
cache = ExLlamaV2Cache(model)
|
||||
|
||||
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
|
||||
|
||||
result = self()
|
||||
|
@ -4,7 +4,12 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_8bit,
|
||||
ExLlamaV2Config
|
||||
)
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
@ -40,11 +45,18 @@ class Exllamav2HF(PreTrainedModel):
|
||||
self.generation_config = GenerationConfig()
|
||||
self.loras = None
|
||||
|
||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
||||
self.past_seq = None
|
||||
if shared.args.cache_8bit:
|
||||
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
|
||||
else:
|
||||
self.ex_cache = ExLlamaV2Cache(self.ex_model)
|
||||
|
||||
self.past_seq = None
|
||||
if shared.args.cfg_cache:
|
||||
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
|
||||
if shared.args.cache_8bit:
|
||||
self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
|
||||
else:
|
||||
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
|
||||
|
||||
self.past_seq_negative = None
|
||||
|
||||
def _validate_model_class(self):
|
||||
|
@ -41,6 +41,8 @@ loaders_and_params = OrderedDict({
|
||||
'gpu_split',
|
||||
'max_seq_len',
|
||||
'cfg_cache',
|
||||
'no_flash_attn',
|
||||
'cache_8bit',
|
||||
'alpha_value',
|
||||
'compress_pos_emb',
|
||||
'use_fast',
|
||||
@ -56,6 +58,8 @@ loaders_and_params = OrderedDict({
|
||||
'ExLlamav2': [
|
||||
'gpu_split',
|
||||
'max_seq_len',
|
||||
'no_flash_attn',
|
||||
'cache_8bit',
|
||||
'alpha_value',
|
||||
'compress_pos_emb',
|
||||
],
|
||||
|
@ -118,6 +118,7 @@ parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM
|
||||
parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
|
||||
parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
|
||||
parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
|
||||
parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')
|
||||
|
||||
# AutoGPTQ
|
||||
parser.add_argument('--triton', action='store_true', help='Use triton.')
|
||||
|
@ -68,6 +68,8 @@ def list_model_elements():
|
||||
'no_use_cuda_fp16',
|
||||
'disable_exllama',
|
||||
'cfg_cache',
|
||||
'no_flash_attn',
|
||||
'cache_8bit',
|
||||
'threads',
|
||||
'threads_batch',
|
||||
'n_batch',
|
||||
|
@ -125,6 +125,8 @@ def create_ui():
|
||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='To enable this option, start the web UI with the --trust-remote-code flag. It is necessary for some models.', interactive=shared.args.trust_remote_code)
|
||||
shared.gradio['use_fast'] = gr.Checkbox(label="use_fast", value=shared.args.use_fast, info='Set use_fast=True while loading the tokenizer. May trigger a conversion that takes several minutes.')
|
||||
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
|
||||
shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
|
||||
shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
|
||||
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
|
||||
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
|
||||
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user