mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
llama : compute BERT graph with F16 K, V
ggml-ci
This commit is contained in:
parent
6cdabe6526
commit
0ba20ed97a
@ -6175,7 +6175,7 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
|
||||
struct ggml_tensor * k = ggml_cast(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3), GGML_TYPE_F16);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
cb(kq, "kq", il);
|
||||
@ -6183,7 +6183,7 @@ struct llm_build_context {
|
||||
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
|
||||
struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
|
||||
struct ggml_tensor * v = ggml_cast(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)), GGML_TYPE_F16);
|
||||
cb(v, "v", il);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
|
||||
|
Loading…
Reference in New Issue
Block a user