diff --git a/ggml-metal.m b/ggml-metal.m index 046643146..0b1119c4e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int nwarps = 1; + const int64_t nwarps = 2; - const size_t shalf = sizeof(float)/2; + const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2); - GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0]; + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index d97952f2b..789b19bad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16( uint3 ntg[[threads_per_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int64_t iq3 = tgpig[2]; - const int64_t iq2 = tgpig[1]; - const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; + //const int64_t iq3 = tgpig[2]; + //const int64_t iq2 = tgpig[1]; + //const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg; - if (iq1 >= ne01) { + const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups + + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4; + const int64_t iq1 = tgpig[0]; + + if (iq2 >= ne02) { return; } - const int64_t D4 = D/4; - - // TODO: can we move this to the stack? - threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D); - - // initialize with zeros - for (int64_t d = 0; d < D4; ++d) { - V16[d] = 0.0h; - } - - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D); - - half S = 0.0h; - half M = -INFINITY; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; - // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv2 = iq2 / rv2; const int64_t iv3 = iq3 / rv3; - // load Q to shared memory - for (int64_t d = 0; d < D4; ++d) { - pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; + + device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr; + +// const int64_t D4 = D/4; +// +// // TODO: can we move this to the stack? +// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared); +// +// // initialize with zeros +// for (int64_t d = 0; d < D4; ++d) { +// +// } +// +// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D); +// +// // load Q to shared memory +// for (int64_t d = 0; d < D4; ++d) { +// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d]; +// } +// +// half S = 0.0h; +// half M = -INFINITY; +// +// for (int64_t ic = 0; ic < ne11; ++ic) { +// const half mv = mp ? mp[ic] : 0.0h; +// if (mv == -INFINITY) { +// continue; +// } +// +// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); +// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); +// +// half4 s4 = 0.0h; +// +// for (int64_t d = 0; d < D4; ++d) { +// s4 += pk4[d] * pq4[d]; +// } +// +// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; +// +// const half Mold = M; +// +// M = max(M, s); +// +// const half ms = exp(Mold - M); +// const half vs = exp(s - M); +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] = V16[d]*ms + pv4[d]*vs; +// } +// +// S = S*ms + vs; +// } +// +// for (int64_t d = 0; d < D4; ++d) { +// V16[d] /= S; +// } +// +// // dst indices +// const int64_t i1 = iq1; +// const int64_t i2 = iq2; +// const int64_t i3 = iq3; +// +// device float4 * dst4 = (device float4 *) dst; +// +// for (int64_t d = 0; d < D4; ++d) { +// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; +// } + + const int64_t D4 = D/4; + + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) ); + threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D); + + const uint tiih = tiisg%4; // thread index in head + const uint hiisg = tiisg/4; // head index in simdgroup + + // load 8 heads from Q to shared memory + for (int64_t i = 0; i < D4/4; ++i) { + pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih]; + ps4[hiisg*D4 + 4*i + tiih] = 0.0h; } + simdgroup_barrier(mem_flags::mem_threadgroup); + + half S = 0.0h; + half M = -INFINITY; + for (int64_t ic = 0; ic < ne11; ++ic) { const half mv = mp ? mp[ic] : 0.0h; if (mv == -INFINITY) { @@ -2097,30 +2170,52 @@ kernel void kernel_flash_attn_ext_f16( half4 s4 = 0.0h; - for (int64_t d = 0; d < D4; ++d) { - s4 += pk4[d] * pq4[d]; + for (int64_t i = 0; i < D4/4; ++i) { + s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih]; } - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; + ss4[hiisg*4 + tiih] = s4; - const half Mold = M; + simdgroup_barrier(mem_flags::mem_threadgroup); - M = max(M, s); + if (tiih == 0) { + s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3]; - const half ms = exp(Mold - M); - const half vs = exp(s - M); + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv; - for (int64_t d = 0; d < D4; ++d) { - V16[d] = V16[d]*ms + pv4[d]*vs; + const half Mold = M; + + M = max(M, s); + + const half ms = exp(Mold - M); + const half vs = exp(s - M); + + S = S*ms + vs; + + ss[2*hiisg + 0] = ms; + ss[2*hiisg + 1] = vs; } - S = S*ms + vs; + simdgroup_barrier(mem_flags::mem_threadgroup); + + const half ms = ss[2*hiisg + 0]; + const half vs = ss[2*hiisg + 1]; + + for (int64_t i = 0; i < D4/4; ++i) { + ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs; + } } - for (int64_t d = 0; d < D4; ++d) { - V16[d] /= S; + simdgroup_barrier(mem_flags::mem_threadgroup); + + if (tiih == 0) { + for (int64_t i = 0; i < D4; ++i) { + ps4[hiisg*D4 + i] /= S; + } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // dst indices const int64_t i1 = iq1; const int64_t i2 = iq2; @@ -2128,8 +2223,8 @@ kernel void kernel_flash_attn_ext_f16( device float4 * dst4 = (device float4 *) dst; - for (int64_t d = 0; d < D4; ++d) { - dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d]; + for (int64_t i = 0; i < D4/4; ++i) { + dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih]; } }