diff --git a/ggml-metal.m b/ggml-metal.m index b64bb7800..b41f29681 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,16 +2253,16 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nwarps = 4; - const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !! - const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !! + const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !! - const size_t smem = nqptg*(nhptg*ne00 + nwarps*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: diff --git a/ggml-metal.metal b/ggml-metal.metal index 3edd7e759..2c22b143c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2091,8 +2091,7 @@ kernel void kernel_flash_attn_ext_f16( // load H heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { - if (sgitg < Q) { - const int64_t j = sgitg; + for (int64_t j = sgitg; j < Q; j += nsg) { if (iq1 + j < ne01) { pq4[j*T4 + hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih]; } else { @@ -2180,28 +2179,28 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_barrier(mem_flags::mem_none); - for (int p = 0; p < 8; ++p) { - const int64_t ic = iic + p; + for (int64_t i = 0; i < D4/tph; ++i) { + half4 pv4v[8]; - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int p = 0; p < 8; ++p) { + const int64_t ic = iic + p; - half ms[Q] = { 1.0h }; - half vs[Q] = { 0.0h }; + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - for (int64_t j = 0; j < Q; ++j) { - ms[j] = ss[j*T + 32*p + 2*hiisg + 0]; - vs[j] = ss[j*T + 32*p + 2*hiisg + 1]; - } - - thread half4 pv4v[D4/tph]; - for (int64_t i = 0; i < D4/tph; ++i) { - pv4v[i] = pv4[tph*i + tiih]; + pv4v[p] = pv4[tph*i + tiih]; } for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = 0; i < D4/tph; ++i) { - ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms[j] + pv4v[i]*vs[j]; + half4 ps4v = ps4[j*T4 + hiisg*D4 + tph*i + tiih]; + + for (int p = 0; p < 8; ++p) { + const half ms = ss[j*T + 32*p + 2*hiisg + 0]; + const half vs = ss[j*T + 32*p + 2*hiisg + 1]; + + ps4v = ps4v*ms + pv4v[p]*vs; } + + ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4v; } } }