diff --git a/ggml-metal.m b/ggml-metal.m index e00069624..2bbb6d17a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute( const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); // simdgroups per threadgroup (a.k.a. warps) // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) diff --git a/ggml-metal.metal b/ggml-metal.metal index 3d5d762d1..d9a536ae8 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16( // online softmax if (C == 32) { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const int64_t p = tiisg; @@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(max(smax, s)); M[j] = simd_max(max(M[j], s)); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); - S[j] = S[j]*ms + simd_sum(vs); - - // create a QxQ diagonal matrix for rescaling the output - if (p == j) { - ss[j*T + C + j] = ms; - } + S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*T + p] = vs; } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } } else { + half ms[Q]; + for (int64_t j = 0; j < Q; ++j) { const half m = M[j]; @@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16( smax = simd_max(smax); M[j] = simd_max(M[j]); - const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*T + C + j] = ms; - } + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); // local sum half ls = 0.0h; @@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16( ss[j*T + p] = vs; } - S[j] = S[j]*ms + simd_sum(ls); + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; } }