From 4ea260098f297900853912eeb9e13c0381eb483d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 29 Jun 2024 09:10:33 -0700
Subject: [PATCH] llama.cpp: add 4-bit/8-bit kv cache options

---
 modules/llamacpp_hf.py    | 7 +++++++
 modules/llamacpp_model.py | 7 +++++++
 modules/loaders.py        | 4 ++++
 3 files changed, 18 insertions(+)

diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py
index 74af5fbf..ed0347d7 100644
--- a/modules/llamacpp_hf.py
+++ b/modules/llamacpp_hf.py
@@ -221,6 +221,13 @@ class LlamacppHF(PreTrainedModel):
             'flash_attn': shared.args.flash_attn
         }
 
+        if shared.args.cache_4bit:
+            params["type_k"] = 2
+            params["type_v"] = 2
+        elif shared.args.cache_8bit:
+            params["type_k"] = 8
+            params["type_v"] = 8
+
         Llama = llama_cpp_lib().Llama
         model = Llama(**params)
 
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index d62fd517..fe7c1efe 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -100,6 +100,13 @@ class LlamaCppModel:
             'flash_attn': shared.args.flash_attn
         }
 
+        if shared.args.cache_4bit:
+            params["type_k"] = 2
+            params["type_v"] = 2
+        elif shared.args.cache_8bit:
+            params["type_k"] = 8
+            params["type_v"] = 8
+
         result.model = Llama(**params)
         if cache_capacity > 0:
             result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
diff --git a/modules/loaders.py b/modules/loaders.py
index 1da37595..5d3adacf 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -30,6 +30,8 @@ loaders_and_params = OrderedDict({
     'llama.cpp': [
         'n_ctx',
         'n_gpu_layers',
+        'cache_8bit',
+        'cache_4bit',
         'tensor_split',
         'n_batch',
         'threads',
@@ -51,6 +53,8 @@ loaders_and_params = OrderedDict({
     'llamacpp_HF': [
         'n_ctx',
         'n_gpu_layers',
+        'cache_8bit',
+        'cache_4bit',
         'tensor_split',
         'n_batch',
         'threads',
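
Note (not part of the patch itself): the integer values assigned to type_k and type_v are ggml quantization type IDs (2 is GGML_TYPE_Q4_0, 8 is GGML_TYPE_Q8_0). Below is a minimal sketch of how these parameters reach llama-cpp-python directly, outside the webui, assuming a llama-cpp-python build that exposes type_k, type_v, and flash_attn on the Llama constructor; the model path is a placeholder.

    # Sketch only: load a GGUF model with llama-cpp-python using a quantized
    # KV cache, mirroring the params dict built by the patch above.
    # Assumes type_k/type_v/flash_attn are supported by the installed
    # llama-cpp-python; the model path is a placeholder.
    from llama_cpp import Llama

    llm = Llama(
        model_path="models/example-7b.Q4_K_M.gguf",  # placeholder path
        n_ctx=4096,
        n_gpu_layers=-1,
        flash_attn=True,  # a quantized V cache generally requires flash attention
        type_k=2,         # GGML_TYPE_Q4_0 keys (4-bit)
        type_v=2,         # GGML_TYPE_Q4_0 values (4-bit)
    )
    out = llm("Hello,", max_tokens=8)
    print(out["choices"][0]["text"])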