From 806382a3a640e924ef4f4a44288be92f8e07c3b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 25 Jan 2024 09:39:22 +0200 Subject: [PATCH] wip : simdify ms, vs --- ggml-metal.metal | 89 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 66 insertions(+), 23 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 0972b5282..be38c17f0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2052,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16( threadgroup half * pq = (threadgroup half *) (shared + 0*D); threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D); threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D); + //threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D); half4 ps4[Q][L4]; @@ -2079,10 +2079,10 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); - half S = { 0.0h }; - half M = { -INFINITY }; - { + half S[Q] = { 0.0h }; + half M[Q] = { -INFINITY }; + // assume K and V are same shape const int64_t ne22 = ne12; const int64_t ne23 = ne13; @@ -2107,6 +2107,7 @@ kernel void kernel_flash_attn_ext_f16( const int64_t iv3 = iq3 / rv3; simdgroup_half8x8 mq[D8]; + for (int64_t i = 0; i < D8; ++i) { simdgroup_load(mq[i], pq + i*8, T); } @@ -2151,46 +2152,85 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_store(mqk, ss, T, 0, false); } - if (tiisg < Q) { - const int64_t j = tiisg; + //if (tiisg < Q) { + // const int64_t j = tiisg; - for (int p = 0; p < C; ++p) { - const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]); + // for (int p = 0; p < C; ++p) { + // const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]); - const half m = M; + // const half m = M; - M = max(M, s); + // M = max(M, s); - const half ms = m == -INFINITY ? 0.0h : exp(m - M); - const half vs = s == -INFINITY ? 0.0h : exp(s - M); + // const half ms = m == -INFINITY ? 0.0h : exp(m - M); + // const half vs = s == -INFINITY ? 0.0h : exp(s - M); - S = S*ms + vs; + // S = S*ms + vs; - ss[j*T + 0 + p] = ms; + // ss[j*T + 0 + p] = ms; + // ss[j*T + C + p] = vs; + // } + //} + + // not sure why this barrier is needed + simdgroup_barrier(mem_flags::mem_none); + + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg % C; + + const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]); + + half m = M[j]; + + M[j] = simd_max(max(M[j], s)); + + const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j]*ms + 0.25h*simd_sum(vs); + + for (int64_t i = 0; i < L4; ++i) { + ps4[j][i] *= ms; + } + + if (tiisg < C) { ss[j*T + C + p] = vs; } } - for (int p = 0; p < C; ++p) { + for (int64_t p = 0; p < C; ++p) { const int64_t ic = iic + p; device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); for (int64_t j = 0; j < Q; ++j) { - const half ms = ss[j*T + 0 + p]; const half vs = ss[j*T + C + p]; for (int64_t i = 0; i < L4; ++i) { - ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs; + ps4[j][i] += pv4[N4*i + tiisg]*vs; } } } + + //for (int p = 0; p < C; ++p) { + // const int64_t ic = iic + p; + // device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + // for (int64_t j = 0; j < Q; ++j) { + // const half ms = ss[j*T + 0 + p]; + // const half vs = ss[j*T + C + p]; + + // for (int64_t i = 0; i < L4; ++i) { + // ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs; + // } + // } + //} } - if (tiisg < Q) { - const int64_t j = tiisg; - - ss[j*T + 0] = S; - ss[j*T + 1] = M; + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } } } @@ -2198,6 +2238,9 @@ kernel void kernel_flash_attn_ext_f16( // reduce the warps + half S = { 0.0h }; + half M = { -INFINITY }; + for (int64_t sg = 1; sg < nsg; ++sg) { if (sgitg == sg) { // store heads to shared memory - reuse pq4 @@ -2241,7 +2284,7 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - S = ss[j*T + 0]; + const half S = ss[j*T + 0]; for (int64_t i = 0; i < L4; ++i) { ps4[j][i] = ps4[j][i]/S; }