diff --git a/ggml-metal.m b/ggml-metal.m index a7e126bff..484ef8939 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2213,12 +2213,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8) + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32) // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); diff --git a/ggml-metal.metal b/ggml-metal.metal index b564f014d..7b604eb61 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2041,7 +2041,6 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; const int64_t D8 = D/8; const int64_t NW = N_SIMDWIDTH; - const int64_t L4 = (D4 + NW - 1)/NW; const int64_t SH = (C + Q); // shared memory per simdgroup in (half) const int64_t T = D + nsg*SH; // shared memory size per query in (half) @@ -2054,14 +2053,15 @@ kernel void kernel_flash_attn_ext_f16( // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) simdgroup_half8x8 lo[D8]; - for (int64_t i = 0; i < L4; ++i) { - // load heads from Q to shared memory - for (int64_t j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { if (iq1 + j < ne01) { - sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg]; + sq4[j*T4 + i] = (half4) q4[i]; } else { - sq4[j*T4 + NW*i + tiisg] = 0.0h; + sq4[j*T4 + i] = 0.0h; } } } @@ -2072,12 +2072,9 @@ kernel void kernel_flash_attn_ext_f16( } // zero out shared memory SH - if (tiisg < C) { - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + tiisg] = 0.0h; - if (tiisg < Q) { - ss[j*T + C + tiisg] = 0.0h; - } + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; } } @@ -2157,27 +2154,34 @@ kernel void kernel_flash_attn_ext_f16( // online softmax for (int64_t j = 0; j < Q; ++j) { - const int64_t p = tiisg; - - const half s = ss[j*T + p]; - - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); - const half m = M[j]; - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; - S[j] = S[j]*ms + simd_sum(vs); + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + } + + const half ms = exp(m - M[j]); + + S[j] = S[j]*ms; // create an 8x8 diagonal matrix for rescaling the output - if (p == j) { + if (tiisg == j) { ss[j*T + C + j] = ms; } - // the P matrix from the paper (Q rows, C columns) - ss[j*T + p] = vs; + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = exp(s - M[j]); + + S[j] = S[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } } // skip -INF blocks @@ -2231,7 +2235,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - // each simdgroup stores its output to shared memory, reusing sq4 + // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2284,7 +2288,7 @@ kernel void kernel_flash_attn_ext_f16( } } - // store result to shared memory (reuse sq4) + // store result to shared memory (reuse sq) if (sgitg == 0) { for (int64_t i = 0; i < D8; ++i) { simdgroup_store(lo[i], sq + i*8, T, 0, false); @@ -2298,8 +2302,8 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = 0; i < L4; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S; + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; } } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4c98bef7c..4093a52f2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1395,7 +1395,7 @@ struct test_flash_attn_ext : public test_case { } double max_nmse_err() override { - return 5e-5; + return 5e-4; } test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) @@ -1677,9 +1677,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer