Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 06:39:25 +01:00)
llama : compute BERT graph with F16 K, V

ggml-ci

Commit: 0ba20ed97a (parent: 6cdabe6526)
@@ -6175,7 +6175,7 @@ struct llm_build_context {
                 }
             }

             struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+            struct ggml_tensor * k = ggml_cast(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3), GGML_TYPE_F16);

             struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             cb(kq, "kq", il);
@@ -6183,7 +6183,7 @@ struct llm_build_context {
             kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);

-            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+            struct ggml_tensor * v = ggml_cast(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)), GGML_TYPE_F16);
             cb(v, "v", il);

             struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
Loading…
Reference in New Issue
Block a user