llama : fix type of KQ_mask and KQ_pos

This commit is contained in:
Georgi Gerganov 2024-03-22 17:12:17 +02:00
parent 9495d3982d
commit 3a468e6f9f
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@@ -5810,20 +5810,20 @@ struct llm_build_context {
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) { if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} else { } else {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} }
cb(lctx.inp_KQ_mask, "KQ_mask", -1); cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask); ggml_set_input(lctx.inp_KQ_mask);
return lctx.inp_KQ_mask; return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
} }
struct ggml_tensor * build_inp_KQ_pos() { struct ggml_tensor * build_inp_KQ_pos() {
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F16, n_kv); lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
cb(lctx.inp_KQ_pos, "KQ_pos", -1); cb(lctx.inp_KQ_pos, "KQ_pos", -1);
ggml_set_input(lctx.inp_KQ_pos); ggml_set_input(lctx.inp_KQ_pos);
return lctx.inp_KQ_pos; return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
} }
struct ggml_tensor * build_inp_mean() { struct ggml_tensor * build_inp_mean() {