diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index edce74108..89f12724d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec( const short D4 = D/4; const short D16 = D/16; const short NW = N_SIMDWIDTH; - const short NW4 = NW/4; + const short NL = NW/4; const short SH = 2*C; // shared memory per simdgroup const short T = D + nsg*SH; // shared memory size per query in (half) @@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec( threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o4x4_t lo[D16/NW4]; + o4x4_t lo[D16/NL]; // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); @@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec( } // zero out lo - for (short i = 0; i < D16/NW4; i += NW4) { + for (short i = 0; i < D16/NL; ++i) { lo[i] = (o4x4_t) 0.0f; } @@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec( half M = -__FLT16_MAX__/2; // thread indices inside the simdgroup - const short tx = tiisg%8; - const short ty = tiisg/8; + const short tx = tiisg%NL; + const short ty = tiisg/NL; // broadcast kv //const short rk2 = ne02/ne12; @@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec( const short ikv3 = iq3/(ne03/ne_12_3); // load the queries from shared memory into local memory - q4x4_t mq[D16/NW4]; + q4x4_t mq[D16/NL]; - for (short ii = 0; ii < D16; ii += NW4) { - mq[ii/NW4] = sq4x4[ii + tx]; + for (short ii = 0; ii < D16; ii += NL) { + mq[ii/NL] = sq4x4[ii + tx]; } const bool has_mask = mask != q; @@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec( device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); #pragma unroll - for (short ii = 0; ii < D16; ii += NW4) { + for (short ii = 0; ii < D16; ii += NL) { const short i = ii + tx; k4x4_t mk; deq_k(pk + i/nl_k, i%nl_k, mk); mqk += - dot(mq[ii/NW4][0], mk[0]) + - dot(mq[ii/NW4][1], mk[1]) + - dot(mq[ii/NW4][2], mk[2]) + - dot(mq[ii/NW4][3], mk[3]); + dot(mq[ii/NL][0], mk[0]) + + dot(mq[ii/NL][1], mk[1]) + + dot(mq[ii/NL][2], mk[2]) + + dot(mq[ii/NL][3], mk[3]); } // simdgroup reduce @@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec( // O = diag(ms)*O #pragma unroll - for (short ii = 0; ii < D16; ii += NW4) { - lo[ii/NW4] *= ms; + for (short ii = 0; ii < D16; ii += NL) { + lo[ii/NL] *= ms; } } @@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec( const s4x4_t ms(ss[4*cc + ty]); #pragma unroll - for (short ii = 0; ii < D16; ii += NW4) { + for (short ii = 0; ii < D16; ii += NL) { const short i = ii + tx; v4x4_t mv; deq_v(pv4 + i/nl_v, i%nl_v, mv); - lo[ii/NW4] += mv*ms; + lo[ii/NL] += mv*ms; } } } @@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec( // [ 5, 13, 21, 29] -> [ 5] // [ 6, 14, 22, 30] -> [ 6] // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < D16; ii += NW4) { - lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16); - lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8); + for (short ii = 0; ii < D16; ii += NL) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16); - lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16); - lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16); - lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); } + threadgroup_barrier(mem_flags::mem_threadgroup); + // store results to shared memory - for (short i = tiisg; i < D16; i += NW4) { - sr4x4[i] = lo[i/NW4]; + for (short i = tiisg; i < D16; i += NL) { + sr4x4[i] = lo[i/NL]; } threadgroup_barrier(mem_flags::mem_threadgroup);