From d917746ddb053b73e868fd6e1854ac17b62bd863 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 15:00:49 +0200 Subject: [PATCH] metal : avoid redundant loads of the attention --- ggml-metal.metal | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 9b6ceec4e..785a60e50 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } + simdgroup_barrier(mem_flags::mem_none); + // (Q*K^T)*V { simdgroup_half8x8 mv; + simdgroup_half8x8 mp[C/8]; + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); + } + for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[C/8]; simdgroup_half8x8 mqkv; simdgroup_load(mqkv, ps + i*8, T, 0, false); - for (int cc = 0; cc < C/8; ++cc) { - simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); - } - for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));