diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index be5ef8d3..9b6da83c 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -59,9 +59,7 @@ class Exllamav2Model:
             model.load(split)
 
         # Determine the correct cache type
-        kv_cache_type = 'fp16'
-        if shared.args.cache_type:
-            kv_cache_type = shared.args.cache_type.lower()
+        kv_cache_type = shared.args.cache_type.lower()
 
         if kv_cache_type == 'fp16':
             cache_type = ExLlamaV2Cache
diff --git a/modules/exllamav2_hf.py b/modules/exllamav2_hf.py
index f6b943c8..62d1e054 100644
--- a/modules/exllamav2_hf.py
+++ b/modules/exllamav2_hf.py
@@ -47,9 +47,7 @@ class Exllamav2HF(PreTrainedModel):
             self.ex_model.load(split)
 
         # Determine the correct cache type
-        kv_cache_type = 'fp16'
-        if shared.args.cache_type:
-            kv_cache_type = shared.args.cache_type.lower()
+        kv_cache_type = shared.args.cache_type.lower()
 
         if kv_cache_type == 'fp16':
             cache_type = ExLlamaV2Cache
diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index 79098250..f9964fe8 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -197,7 +197,7 @@ class LlamacppHF(PreTrainedModel):
             'flash_attn': shared.args.flash_attn
         }
 
-        if shared.args.cache_type:
+        if shared.args.cache_type != 'fp16':
             params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
             params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
 
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index c8b3456e..6a76ee4e 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -104,7 +104,7 @@ class LlamaCppModel:
             'flash_attn': shared.args.flash_attn
         }
 
-        if shared.args.cache_type:
+        if shared.args.cache_type != 'fp16':
             params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
             params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
 
diff --git a/modules/shared.py b/modules/shared.py
index 41865e22..cab61226 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -166,7 +166,7 @@ group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunne
 
 # Cache
 group = parser.add_argument_group('Cache')
-group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
 
 # DeepSpeed
 group = parser.add_argument_group('DeepSpeed')
@@ -295,12 +295,11 @@ def transform_legacy_kv_cache_options(opts):
 
     # Retrieve values
     loader = get('loader')
-    cache_type = get('cache_type')
     cache_8bit = get('cache_8bit')
     cache_4bit = get('cache_4bit')
 
     # Determine cache type based on loader or legacy flags
-    if not cache_type:
+    if cache_8bit or cache_4bit:
         if not loader:
             # Legacy behavior: prefer 8-bit over 4-bit to minimize breakage
             if cache_8bit: