UI: Set cache_type to fp16 by default

This commit is contained in:
oobabooga 2024-12-17 19:44:20 -08:00
parent ddccc0d657
commit 60c93e0c66
5 changed files with 6 additions and 11 deletions

View File

@@ -59,9 +59,7 @@ class Exllamav2Model:
 model.load(split)
 # Determine the correct cache type
-kv_cache_type = 'fp16'
-if shared.args.cache_type:
-    kv_cache_type = shared.args.cache_type.lower()
+kv_cache_type = shared.args.cache_type.lower()
 if kv_cache_type == 'fp16':
     cache_type = ExLlamaV2Cache

View File

@@ -47,9 +47,7 @@ class Exllamav2HF(PreTrainedModel):
 self.ex_model.load(split)
 # Determine the correct cache type
-kv_cache_type = 'fp16'
-if shared.args.cache_type:
-    kv_cache_type = shared.args.cache_type.lower()
+kv_cache_type = shared.args.cache_type.lower()
 if kv_cache_type == 'fp16':
     cache_type = ExLlamaV2Cache

View File

@@ -197,7 +197,7 @@ class LlamacppHF(PreTrainedModel):
 'flash_attn': shared.args.flash_attn
 }
-if shared.args.cache_type:
+if shared.args.cache_type != 'fp16':
     params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
     params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)

View File

@@ -104,7 +104,7 @@ class LlamaCppModel:
 'flash_attn': shared.args.flash_attn
 }
-if shared.args.cache_type:
+if shared.args.cache_type != 'fp16':
     params["type_k"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)
     params["type_v"] = get_llamacpp_cache_type_for_string(shared.args.cache_type)

View File

@@ -166,7 +166,7 @@ group.add_argument('--cpp-runner', action='store_true', help='Use the ModelRunne
 # Cache
 group = parser.add_argument_group('Cache')
-group.add_argument('--cache_type', type=str, default=None, help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
 # DeepSpeed
 group = parser.add_argument_group('DeepSpeed')
@@ -295,12 +295,11 @@ def transform_legacy_kv_cache_options(opts):
 # Retrieve values
 loader = get('loader')
-cache_type = get('cache_type')
 cache_8bit = get('cache_8bit')
 cache_4bit = get('cache_4bit')
 # Determine cache type based on loader or legacy flags
-if not cache_type:
+if cache_8bit or cache_4bit:
     if not loader:
         # Legacy behavior: prefer 8-bit over 4-bit to minimize breakage
         if cache_8bit: