llama : better express the KV cache dependencies in the graph

This commit is contained in:
Georgi Gerganov 2023-09-04 21:44:48 +03:00
parent 60c2ef6d92
commit f3a84b2e0d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 41 additions and 31 deletions

2
ggml.c
View File

@ -5213,6 +5213,8 @@ struct ggml_tensor * ggml_view_tensor(
result->nb[i] = src->nb[i];
}
result->op = GGML_OP_VIEW;
return result;
}

View File

@ -2341,45 +2341,53 @@ static struct ggml_cgraph * llm_build_llama(
// compute Q and K and RoPE them
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk");
ggml_set_name (tmpk, "tmpk");
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
ggml_set_name (tmpq, "tmpq");
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name (tmpv, "tmpv");
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
ggml_set_name (Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
ggml_set_name (Qcur, "Qcur");
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
offload_func_v(Vcur);
ggml_set_name (Vcur, "Vcur");
struct ggml_tensor * k;
struct ggml_tensor * v;
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * k_view = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k_view);
ggml_set_name (k_view, "k_view");
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
struct ggml_tensor * v_view = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v");
offload_func_v(v_view);
ggml_set_name (v_view, "v_view");
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
struct ggml_tensor * k_cpy = ggml_cpy(ctx0, Kcur, k_view);
struct ggml_tensor * v_cpy = ggml_cpy(ctx0, Vcur, v_view);
// TODO: replace with ggml_dependency / ggml_depends_on
k = ggml_view_tensor(ctx0, kv_self.k);
v = ggml_view_tensor(ctx0, kv_self.v);
k->src[0] = k_cpy;
v->src[0] = v_cpy;
}
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@ -2387,11 +2395,11 @@ static struct ggml_cgraph * llm_build_llama(
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
ggml_view_3d(ctx0, k,
n_embd_head, n_past + N, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
ggml_element_size(k)*n_embd_gqa,
ggml_element_size(k)*n_embd_head,
ggml_element_size(k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
@ -2418,11 +2426,11 @@ static struct ggml_cgraph * llm_build_llama(
// split cached V into n_head heads
struct ggml_tensor * V =
ggml_cont(ctx0, ggml_view_3d(ctx0, kv_self.v,
ggml_view_3d(ctx0, v,
n_past + N, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il));
ggml_element_size(v)*n_ctx,
ggml_element_size(v)*n_ctx*n_embd_head,
ggml_element_size(v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");
@ -2434,7 +2442,7 @@ static struct ggml_cgraph * llm_build_llama(
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif