From addad3c63eff903f853055a95eb9487df6143a3e Mon Sep 17 00:00:00 2001 From: Diner Burger Date: Tue, 17 Dec 2024 15:43:48 -0500 Subject: [PATCH] Allow more granular KV cache settings (#6561) --- modules/exllamav2.py | 20 +++++++++--- modules/exllamav2_hf.py | 26 +++++++++++----- modules/llamacpp_hf.py | 10 +++--- modules/llamacpp_model.py | 38 +++++++++++++++++++---- modules/loaders.py | 12 +++----- modules/shared.py | 65 +++++++++++++++++++++++++++++++++++++-- modules/ui.py | 3 +- modules/ui_model_menu.py | 3 +- 8 files changed, 140 insertions(+), 37 deletions(-) diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 0498c488..be5ef8d3 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -2,17 +2,19 @@ import traceback from pathlib import Path import torch + from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, + ExLlamaV2Cache_Q6, + ExLlamaV2Cache_Q8, ExLlamaV2Cache_TP, ExLlamaV2Config, ExLlamaV2Tokenizer ) from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator - from modules import shared from modules.logging_colors import logger from modules.text_generation import get_max_prompt_length @@ -57,12 +59,22 @@ class Exllamav2Model: model.load(split) # Determine the correct cache type - if shared.args.cache_8bit: + kv_cache_type = 'fp16' + if shared.args.cache_type: + kv_cache_type = shared.args.cache_type.lower() + + if kv_cache_type == 'fp16': + cache_type = ExLlamaV2Cache + elif kv_cache_type == 'fp8': cache_type = ExLlamaV2Cache_8bit - elif shared.args.cache_4bit: + elif kv_cache_type == 'q8': + cache_type = ExLlamaV2Cache_Q8 + elif kv_cache_type == 'q6': + cache_type = ExLlamaV2Cache_Q6 + elif kv_cache_type == 'q4': cache_type = ExLlamaV2Cache_Q4 else: - cache_type = ExLlamaV2Cache + raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.") # Use TP if specified if shared.args.enable_tp: diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py index 320a8d24..f6b943c8 100644 --- a/modules/exllamav2_hf.py +++ b/modules/exllamav2_hf.py @@ -4,18 +4,20 @@ from pathlib import Path from typing import Any, Dict, Optional, Union import torch +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, + ExLlamaV2Cache_Q6, + ExLlamaV2Cache_Q8, ExLlamaV2Cache_TP, ExLlamaV2Config ) -from torch.nn import CrossEntropyLoss -from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithPast - from modules import shared from modules.logging_colors import logger @@ -45,12 +47,22 @@ class Exllamav2HF(PreTrainedModel): self.ex_model.load(split) # Determine the correct cache type - if shared.args.cache_8bit: + kv_cache_type = 'fp16' + if shared.args.cache_type: + kv_cache_type = shared.args.cache_type.lower() + + if kv_cache_type == 'fp16': + cache_type = ExLlamaV2Cache + elif kv_cache_type == 'fp8': cache_type = ExLlamaV2Cache_8bit - elif shared.args.cache_4bit: + elif kv_cache_type == 'q8': + cache_type = ExLlamaV2Cache_Q8 + elif kv_cache_type == 'q6': + cache_type = ExLlamaV2Cache_Q6 + elif kv_cache_type == 'q4': cache_type = ExLlamaV2Cache_Q4 else: - cache_type = ExLlamaV2Cache + raise ValueError(f"Invalid cache type for ExLlamaV2: {cache_type}. Valid options are: fp16, fp8, q8, q6, q4.") # Use TP if specified if shared.args.enable_tp: diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index 6611a7c1..79098250 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from modules import shared from modules.llama_cpp_python_hijack import llama_cpp_lib +from modules.llamacpp_model import get_llamacpp_cache_type_for_string from modules.logging_colors import logger @@ -196,12 +197,9 @@ class LlamacppHF(PreTrainedModel): 'flash_attn': shared.args.flash_attn } - if shared.args.cache_4bit: - params["type_k"] = 2 - params["type_v"] = 2 - elif shared.args.cache_8bit: - params["type_k"] = 8 - params["type_v"] = 8 + if shared.args.cache_type: + params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) + params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) Llama = llama_cpp_lib().Llama model = Llama(**params) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index 96f7ed56..c8b3456e 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -10,6 +10,35 @@ from modules.llama_cpp_python_hijack import llama_cpp_lib from modules.logging_colors import logger from modules.text_generation import get_max_prompt_length +llamacpp_quant_mapping = { + 'f32': 0, + 'fp16': 1, + 'q4_0': 2, + 'q4_1': 3, + 'q5_0': 6, + 'q5_1': 7, + 'q8_0': 8, + 'q8_1': 9, + 'q2_k': 10, + 'q3_k': 11, + 'q4_k': 12, + 'q5_k': 13, + 'q6_k': 14, + 'q8_k': 15, + 'iq4_nl': 20, + 'bf16': 30, +} + +llamacpp_valid_cache_types = {'fp16', 'q8_0', 'q4_0'} + + +def get_llamacpp_cache_type_for_string(quant_type: str): + quant_type = quant_type.lower() + if quant_type in llamacpp_valid_cache_types: + return llamacpp_quant_mapping[quant_type] + else: + raise ValueError(f"Invalid cache type for llama.cpp: {quant_type}. Valid options are: fp16, q8_0, q4_0.") + def ban_eos_logits_processor(eos_token, input_ids, logits): logits[eos_token] = -float('inf') @@ -75,12 +104,9 @@ class LlamaCppModel: 'flash_attn': shared.args.flash_attn } - if shared.args.cache_4bit: - params["type_k"] = 2 - params["type_v"] = 2 - elif shared.args.cache_8bit: - params["type_k"] = 8 - params["type_v"] = 8 + if shared.args.cache_type: + params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) + params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type) result.model = Llama(**params) if cache_capacity > 0: diff --git a/modules/loaders.py b/modules/loaders.py index deee00a7..4cb7e349 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -31,8 +31,7 @@ loaders_and_params = OrderedDict({ 'llama.cpp': [ 'n_ctx', 'n_gpu_layers', - 'cache_8bit', - 'cache_4bit', + 'cache_type', 'tensor_split', 'n_batch', 'threads', @@ -54,8 +53,7 @@ loaders_and_params = OrderedDict({ 'llamacpp_HF': [ 'n_ctx', 'n_gpu_layers', - 'cache_8bit', - 'cache_4bit', + 'cache_type', 'tensor_split', 'n_batch', 'threads', @@ -87,8 +85,7 @@ loaders_and_params = OrderedDict({ 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'cache_8bit', - 'cache_4bit', + 'cache_type', 'autosplit', 'enable_tp', 'alpha_value', @@ -103,8 +100,7 @@ loaders_and_params = OrderedDict({ 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'cache_8bit', - 'cache_4bit', + 'cache_type', 'autosplit', 'enable_tp', 'alpha_value', diff --git a/modules/shared.py b/modules/shared.py index 599d6492..41865e22 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -142,8 +142,6 @@ group.add_argument('--cfg-cache', action='store_true', help='ExLlamav2_HF: Creat group.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.') group.add_argument('--no_xformers', action='store_true', help='Force xformers to not be used.') group.add_argument('--no_sdpa', action='store_true', help='Force Torch SDPA to not be used.') -group.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.') -group.add_argument('--cache_4bit', action='store_true', help='Use Q4 cache to save VRAM.') group.add_argument('--num_experts_per_token', type=int, default=2, help='Number of experts to use for generation. Applies to MoE models like Mixtral.') group.add_argument('--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) in ExLlamaV2.') @@ -166,6 +164,10 @@ group.add_argument('--hqq-backend', type=str, default='PYTORCH_COMPILE', help='B group = parser.add_argument_group('TensorRT-LLM') group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn\'t support streaming yet.') +# Cache +group = parser.add_argument_group('Cache') +group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.') + # DeepSpeed group = parser.add_argument_group('DeepSpeed') group.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.') @@ -213,6 +215,8 @@ group.add_argument('--pre_layer', type=int, nargs='+', help='DEPRECATED') group.add_argument('--checkpoint', type=str, help='DEPRECATED') group.add_argument('--monkey-patch', action='store_true', help='DEPRECATED') group.add_argument('--no_inject_fused_attention', action='store_true', help='DEPRECATED') +group.add_argument('--cache_4bit', action='store_true', help='DEPRECATED') +group.add_argument('--cache_8bit', action='store_true', help='DEPRECATED') group.add_argument('--chat-buttons', action='store_true', help='DEPRECATED') args = parser.parse_args() @@ -270,6 +274,59 @@ def fix_loader_name(name): return 'TensorRT-LLM' +def transform_legacy_kv_cache_options(opts): + # Handle both argparse.Namespace and dict here + def get(key): + return opts.get(key) if isinstance(opts, dict) else getattr(opts, key, None) + + def set(key, value): + if isinstance(opts, dict): + opts[key] = value + else: + setattr(opts, key, value) + + def del_key(key, fallback_set): + # only remove from user dict, can't delete from argparse.Namespace + if type(opts) is dict: + if key in opts: + del opts[key] + else: + setattr(opts, key, fallback_set) + + # Retrieve values + loader = get('loader') + cache_type = get('cache_type') + cache_8bit = get('cache_8bit') + cache_4bit = get('cache_4bit') + + # Determine cache type based on loader or legacy flags + if not cache_type: + if not loader: + # Legacy behavior: prefer 8-bit over 4-bit to minimize breakage + if cache_8bit: + set('cache_type', 'fp8') + elif cache_4bit: + set('cache_type', 'q4') + elif loader.lower() in ['exllamav2', 'exllamav2_hf']: + # ExLlamaV2 loader-specific cache type + if cache_8bit: + set('cache_type', 'fp8') + elif cache_4bit: + set('cache_type', 'q4') + elif loader.lower() in ['llama.cpp', 'llamacpp_hf']: + # Llama.cpp loader-specific cache type + if cache_4bit: + set('cache_type', 'q4_0') + elif cache_8bit: + set('cache_type', 'q8_0') + + # Clean up legacy keys + del_key('cache_4bit', False) + del_key('cache_8bit', False) + + return opts + + def add_extension(name, last=False): if args.extensions is None: args.extensions = [name] @@ -298,10 +355,14 @@ def load_user_config(): else: user_config = {} + for model_name in user_config: + user_config[model_name] = transform_legacy_kv_cache_options(user_config[model_name]) + return user_config args.loader = fix_loader_name(args.loader) +args = transform_legacy_kv_cache_options(args) # Activate the multimodal extension if args.multimodal_pipeline is not None: diff --git a/modules/ui.py b/modules/ui.py index 9ef958ab..ff5bc165 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -130,8 +130,7 @@ def list_model_elements(): 'no_xformers', 'no_sdpa', 'num_experts_per_token', - 'cache_8bit', - 'cache_4bit', + 'cache_type', 'autosplit', 'enable_tp', 'threads', diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 34d58177..eb61cd82 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -118,8 +118,7 @@ def create_ui(): shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.') shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices) shared.gradio['tensorcores'] = gr.Checkbox(label="tensorcores", value=shared.args.tensorcores, info='NVIDIA only: use llama-cpp-python compiled with tensor cores support. This may increase performance on newer cards.') - shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.') - shared.gradio['cache_4bit'] = gr.Checkbox(label="cache_4bit", value=shared.args.cache_4bit, info='Use Q4 cache to save VRAM.') + shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.') shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming_llm", value=shared.args.streaming_llm, info='(experimental) Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') shared.gradio['attention_sink_size'] = gr.Number(label="attention_sink_size", value=shared.args.attention_sink_size, precision=0, info='StreamingLLM: number of sink tokens. Only used if the trimmed prompt doesn\'t share a prefix with the old prompt.') shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')