llama : handle KV shift for recurrent models (#10402)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-21 10:22:47 +02:00 committed by GitHub
parent 87a533be57
commit 1bb30bf28c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18211,13 +18211,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
static void llama_kv_cache_update_internal(struct llama_context & lctx) { static void llama_kv_cache_update_internal(struct llama_context & lctx) {
bool need_reserve = false; bool need_reserve = false;
// apply K-shift if needed if (lctx.kv_self.has_shift) {
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (!llama_kv_cache_can_shift(&lctx)) { if (!llama_kv_cache_can_shift(&lctx)) {
GGML_ABORT("Deepseek2 does not support K-shift"); GGML_ABORT("The current context does not support K-shift");
} }
{ // apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(lctx.sched.get()); ggml_backend_sched_reset(lctx.sched.get());
ggml_cgraph * gf = llama_build_graph_k_shift(lctx); ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
@ -20463,7 +20463,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
} }
bool llama_kv_cache_can_shift(struct llama_context * ctx) { bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
} }
// deprecated // deprecated