From 3ce30e07c98a8cf5ce7a22a866d4d0fd5436216f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 15:40:23 +0200 Subject: [PATCH] llama : pass KV cache type through API --- common/common.cpp | 34 ++++++++++++++++++++++++++++++++++ common/common.h | 5 ++++- llama.cpp | 29 ++++++++++++++++++----------- llama.h | 3 +++ 4 files changed, 59 insertions(+), 12 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 43c374d5c..77332d5db 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -500,6 +500,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.dump_kv_cache = true; } else if (arg == "-nkvo" || arg == "--no-kv-offload") { params.no_kv_offload = true; + } else if (arg == "-ctk" || arg == "--cache-type-k") { + params.cache_type_k = argv[++i]; + } else if (arg == "-ctv" || arg == "--cache-type-v") { + params.cache_type_v = argv[++i]; } else if (arg == "--multiline-input") { params.multiline_input = true; } else if (arg == "--simple-io") { @@ -844,6 +848,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" verbose print of the KV cache\n"); printf(" -nkvo, --no-kv-offload\n"); printf(" disable KV offload\n"); + printf(" -ctk TYPE, --cache-type-k TYPE\n"); + printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); + printf(" -ctv TYPE, --cache-type-v TYPE\n"); + printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); @@ -908,6 +916,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & return mparams; } +static ggml_type kv_cache_type_from_str(const std::string & s) { + if (s == "f16") { + return GGML_TYPE_F16; + } + if (s == "q8_0") { + return GGML_TYPE_Q8_0; + } + if (s == "q4_0") { + return GGML_TYPE_Q4_0; + } + if (s == "q4_1") { + return GGML_TYPE_Q4_1; + } + if (s == "q5_0") { + return GGML_TYPE_Q5_0; + } + if (s == "q5_1") { + return GGML_TYPE_Q5_1; + } + + throw std::runtime_error("Invalid cache type: " + s); +} + struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto cparams = llama_context_default_params(); @@ -930,6 +961,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.offload_kqv = !params.no_kv_offload; + cparams.type_k = kv_cache_type_from_str(params.cache_type_k); + cparams.type_v = kv_cache_type_from_str(params.cache_type_v); + return cparams; } diff --git a/common/common.h b/common/common.h index 2664c8fc1..7f0d03e41 100644 --- a/common/common.h +++ b/common/common.h @@ -125,9 +125,12 @@ struct gpt_params { bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading + std::string cache_type_k = "f16"; // KV cache data type for the K + std::string cache_type_v = "f16"; // KV cache data type for the V + // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file + std::string image = ""; // path to an image file }; bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); diff --git a/llama.cpp b/llama.cpp index a70e40dba..3f951dbe3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8580,6 +8580,8 @@ struct llama_context_params llama_context_default_params() { /*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_slow =*/ 1.0f, /*.yarn_orig_ctx =*/ 0, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, /*.mul_mat_q =*/ true, /*.f16_kv =*/ true, /*.logits_all =*/ false, @@ -8737,31 +8739,36 @@ struct llama_context * llama_new_context_with_model( //const ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - // TODO: move as params - const ggml_type k_type = GGML_TYPE_Q8_0; - const ggml_type v_type = GGML_TYPE_F16; + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; - GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(k_type) == 0); - GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(v_type) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0); + GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0); // reserve memory for context buffers if (!hparams.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, k_type, v_type, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } { - // const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); - size_t memory_size = 0; + size_t memory_size_k = 0; + size_t memory_size_v = 0; + for (auto & k : ctx->kv_self.k_l) { - memory_size += ggml_nbytes(k); + memory_size_k += ggml_nbytes(k); } + for (auto & v : ctx->kv_self.v_l) { - memory_size += ggml_nbytes(v); + memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } // resized during inference diff --git a/llama.h b/llama.h index c1593c9b0..e45f12975 100644 --- a/llama.h +++ b/llama.h @@ -191,6 +191,9 @@ extern "C" { float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size + ggml_type type_k; // data type for K cache + ggml_type type_v; // data type for V cache + // Keep the booleans together to avoid misalignment during copy-by-value. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) bool f16_kv; // use fp16 for KV cache, fp32 otherwise