llama : fix type of KQ_mask and KQ_pos

This commit is contained in:
Georgi Gerganov 2024-03-22 17:12:17 +02:00
parent 9495d3982d
commit 3a468e6f9f
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -5810,20 +5810,20 @@ struct llm_build_context {
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} else {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
}
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
return lctx.inp_KQ_mask;
return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
}
struct ggml_tensor * build_inp_KQ_pos() {
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv);
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
ggml_set_input(lctx.inp_KQ_pos);
return lctx.inp_KQ_pos;
return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
}
struct ggml_tensor * build_inp_mean() {