use inp_pos

This commit is contained in:
okada 2023-12-17 21:53:04 +09:00
parent 86d5348fd0
commit f76fd39266

View File

@ -5528,9 +5528,9 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
// KQ_pos - contains the positions // inp_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(KQ_pos, "KQ_pos", -1); cb(inp_pos, "inp_pos", -1);
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
@ -5558,13 +5558,13 @@ struct llm_build_context {
cb(tmpq, "tmpq", il); cb(tmpq, "tmpq", il);
struct ggml_tensor * Kcur = ggml_rope_custom( struct ggml_tensor * Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
struct ggml_tensor * Qcur = ggml_rope_custom( struct ggml_tensor * Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);