diff --git a/ggml-metal.m b/ggml-metal.m index a3191e35a..baade0abc 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps) const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; + const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values) //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index 9c5d1ed2e..9b6ceec4e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2055,10 +2055,8 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D); threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D); - half4 ls4[Q][L4]; - - // load heads from Q to shared memory for (int64_t i = 0; i < L4; ++i) { + // load heads from Q to shared memory for (int64_t j = sgitg; j < Q; j += nsg) { if (iq1 + j < ne01) { pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg]; @@ -2067,8 +2065,9 @@ kernel void kernel_flash_attn_ext_f16( } } + // zero out shared memory for (int64_t j = 0; j < Q; ++j) { - ls4[j][i] = 0.0h; + ps4[j*T4 + N4*i + tiisg] = 0.0h; } } @@ -2113,6 +2112,7 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_load(mq[i], pq + i*8, T); } + // TODO: this can be improved device const float * mp[Q]; { @@ -2128,10 +2128,26 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + // skip -INF blocks + // TODO: double-check this + { + float smc = -INFINITY; + + for (int64_t j = 0; j < Q; ++j) { + const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY; + smc = simd_max(max(smc, mc)); + } + + if (smc == -INFINITY) { + continue; + } + } + + // Q*K^T { simdgroup_half8x8 mk; - for (int cc = 0; cc < 4; ++cc) { + for (int cc = 0; cc < C/8; ++cc) { simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); @@ -2146,6 +2162,7 @@ kernel void kernel_flash_attn_ext_f16( } } + // online softmax for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2161,24 +2178,27 @@ kernel void kernel_flash_attn_ext_f16( S[j] = S[j]*ms + simd_sum(vs); for (int64_t i = 0; i < L4; ++i) { - ls4[j][i] *= ms; + ps4[j*T4 + N4*i + tiisg] *= ms; } ss[j*T + p] = vs; } + // (Q*K^T)*V { simdgroup_half8x8 mv; for (int64_t i = 0; i < D8; ++i) { - simdgroup_half8x8 mp[4]; - simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mp[C/8]; + simdgroup_half8x8 mqkv; - for (int cc = 0; cc < 4; ++cc) { + simdgroup_load(mqkv, ps + i*8, T, 0, false); + + for (int cc = 0; cc < C/8; ++cc) { simdgroup_load(mp[cc], ss + 8*cc, T, 0, false); } - for (int cc = 0; cc < 4; ++cc) { + for (int cc = 0; cc < C/8; ++cc) { device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); simdgroup_load(mv, pv + i*8, nb21/2, 0, false); @@ -2189,12 +2209,6 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_store(mqkv, ps + i*8, T, 0, false); } } - - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = 0; i < L4; ++i) { - ls4[j][i] += ps4[j*T4 + N4*i + tiisg]; - } - } } for (int64_t j = 0; j < Q; ++j) { @@ -2208,23 +2222,12 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps + // TODO: try parallel reduce + if (sgitg == 0) { + half S = { 0.0h }; + half M = { -INFINITY }; - half S = { 0.0h }; - half M = { -INFINITY }; - - for (int64_t sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - // store heads to shared memory - reuse pq4 - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = 0; i < L4; ++i) { - pq4[j*T4 + N4*i + tiisg] = ls4[j][i]; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (sgitg == 0) { + for (int64_t sg = 1; sg < nsg; ++sg) { for (int64_t j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*(D + 1*C) + 0]; @@ -2245,21 +2248,10 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t i = 0; i < L4; ++i) { - ls4[j][i] = ls4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1; + ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1; } } } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - if (sgitg == 0) { - for (int64_t j = 0; j < Q; ++j) { - const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - ls4[j][i] = ls4[j][i]/S; - } - } } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -2268,8 +2260,10 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ls4[j][i]; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S; } } }