From 8e752a777b272606f22cb741b03e062de4ddb8fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 19 Nov 2024 13:29:26 +0200 Subject: [PATCH] llama : add check for KV cache shifts (#10401) ggml-ci --- common/common.cpp | 6 ++++++ include/llama.h | 3 +++ src/llama.cpp | 6 +++++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 930374621..d314523db 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) { + LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__); + llama_free_model(model); + return iparams; + } + if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); diff --git a/include/llama.h b/include/llama.h index bc268e799..90791d5f5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -667,6 +667,9 @@ extern "C" { // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); + // Check if the context supports KV cache shifting + LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx); + // // State / sessions // diff --git a/src/llama.cpp b/src/llama.cpp index 4f31f25b1..c51b36e66 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { // apply K-shift if needed if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { - if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA + if (!llama_kv_cache_can_shift(&lctx)) { GGML_ABORT("Deepseek2 does not support K-shift"); } @@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) { llama_kv_cache_update_internal(*ctx); } +bool llama_kv_cache_can_shift(struct llama_context * ctx) { + return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA +} + // deprecated size_t llama_get_state_size(struct llama_context * ctx) { return llama_state_get_size(ctx);