From 1bb30bf28cb5a7adf111bc41c935bdaf128397e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 Nov 2024 10:22:47 +0200 Subject: [PATCH] llama : handle KV shift for recurrent models (#10402) ggml-ci --- src/llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index c51b36e66..001711037 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18211,13 +18211,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; - // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.kv_self.has_shift) { if (!llama_kv_cache_can_shift(&lctx)) { - GGML_ABORT("Deepseek2 does not support K-shift"); + GGML_ABORT("The current context does not support K-shift"); } - { + // apply K-shift if needed + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { ggml_backend_sched_reset(lctx.sched.get()); ggml_cgraph * gf = llama_build_graph_k_shift(lctx); @@ -20463,7 +20463,7 @@ void llama_kv_cache_update(struct llama_context * ctx) { } bool llama_kv_cache_can_shift(struct llama_context * ctx) { - return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA } // deprecated