mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
llama : fix type of KQ_mask and KQ_pos
This commit is contained in:
parent
9495d3982d
commit
3a468e6f9f
10
llama.cpp
10
llama.cpp
@ -5810,20 +5810,20 @@ struct llm_build_context {
|
||||
|
||||
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
|
||||
if (causal) {
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
} else {
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
}
|
||||
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask);
|
||||
return lctx.inp_KQ_mask;
|
||||
return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_KQ_pos() {
|
||||
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv);
|
||||
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
|
||||
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
|
||||
ggml_set_input(lctx.inp_KQ_pos);
|
||||
return lctx.inp_KQ_pos;
|
||||
return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_mean() {
|
||||
|
Loading…
Reference in New Issue
Block a user