diff --git a/llama.cpp b/llama.cpp index c5a1fa0f6..e11f0ac4b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6622,6 +6622,7 @@ static struct ggml_tensor * llm_build_kqv( const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); cb(q, "q", il); @@ -6644,8 +6645,8 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_v), 0); cb(v, "v", il); @@ -6655,7 +6656,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } - cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); @@ -6700,7 +6701,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); }