diff --git a/ggml.c b/ggml.c index ad546a731..026c2205d 100644 --- a/ggml.c +++ b/ggml.c @@ -9650,7 +9650,7 @@ static void ggml_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { // broadcast src0 into src1 across 2nd,3rd dimension const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; + const int64_t i02 = i12%5; const void * x = (char *) src0->data + i02*nb02 + i03*nb03; const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); @@ -9761,7 +9761,7 @@ static void ggml_compute_forward_mul_mat( // broadcast src0 into src1 const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; + const int64_t i02 = i12%5; const int64_t i1 = i11; const int64_t i2 = i12; diff --git a/llama.cpp b/llama.cpp index 25a1a4a7c..f585c0e67 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5543,7 +5543,7 @@ struct llm_build_context { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attention_norm", il); + cb(cur, "attn_norm", il); struct ggml_tensor * attention_norm = cur; @@ -5573,73 +5573,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - auto plamo_llm_build_kqv = []( - struct ggml_context * ctx, - const llama_hparams & hparams, - const llama_kv_cache & kv, - struct ggml_tensor * wo, - struct ggml_tensor * q_cur, - struct ggml_tensor * kq_mask, - int64_t n_ctx, - int32_t n_tokens, - int32_t n_kv, - const llm_build_cb & cb, - int il) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); - cb(q, "q", il); - - struct ggml_tensor * k = - ggml_view_3d(ctx, kv.k_l[il], - n_embd_head, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_gqa), - ggml_row_size(kv.k_l[il]->type, n_embd_head), - 0); - cb(k, "k", il); - - // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); - cb(k_repeated, "k_repeated", il); - - struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); - cb(kq, "kq", il); - - kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); - cb(kq, "kq_soft_max_ext", il); - - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head, - 0); - cb(v, "v", il); - - // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); - cb(k_repeated, "v_repeated", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens); - cb(cur, "kqv_merged_cont", il); - - cur = ggml_mul_mat(ctx, wo, cur); - return cur; - }; - - cur = plamo_llm_build_kqv(ctx0, hparams, kv_self, - model.layers[il].wo, - Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il); + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); cb(cur, "kqv_out", il); } struct ggml_tensor * sa_out = cur; @@ -5653,13 +5589,14 @@ struct llm_build_context { model.layers[il].ffn_gate, NULL, model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "mlp_out", il); + cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, sa_out); - cb(cur, "mlp_out + sa_out", il); + cb(cur, "l_out", il); + cur = ggml_add(ctx0, cur, inpL); - cb(cur, "mlp_out + sa_out + inpL", il); + cb(cur, "l_out", il); // input for next layer inpL = cur;