From da23b56f259843ccbda705bea58457d872dc3074 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Jan 2024 13:25:34 +0200 Subject: [PATCH] wip : no ic 8 step --- ggml-metal.m | 4 +- ggml-metal.metal | 105 ++++++++++++++++++++--------------------------- 2 files changed, 47 insertions(+), 62 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 00a8a0e92..0a9fefeda 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,12 +2253,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = 8; // simdgroups per threadgroup (a.k.a. warps) const int64_t nhptg = 4; // heads per threadgroup !! sync with kernel template arguments !! const int64_t nqptg = 2; // queries per threadgroup !! sync with kernel template arguments !! //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(nhptg*ne00 + nsg*(256))*(sizeof(float)/2); + const size_t smem = nqptg*(nhptg*ne00 + nsg*(32))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 1a6eaed14..9f9841baa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2078,12 +2078,12 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - const int64_t T = (H*D + nsg*(256)); // shared memory size per query in half + const int64_t T = (H*D + nsg*(32)); // shared memory size per query in half const int64_t T4 = T/4; // shared memory size per query in half4 - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(256) + 1*H*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(256) + 1*H*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(32) + 1*H*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(32) + 1*H*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup @@ -2116,39 +2116,35 @@ kernel void kernel_flash_attn_ext_f16( half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t iic = 8*sgitg; iic < ne11; iic += 8*nsg) { + for (int64_t ic = sgitg; ic < ne11; ic += nsg) { half mv[Q]; bool skip = true; for (int64_t j = 0; j < Q; ++j) { - mv[j] = mp[j][iic]; + mv[j] = mp[j][ic]; skip = skip && (mv[j] == -INFINITY); } if (skip) { continue; } - for (int p = 0; p < 8; ++p) { - const int64_t ic = iic + p; + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + half s[Q] = { 0.0h }; + half4 pk4v[D4/tph]; - half s[Q] = { 0.0h }; - half4 pk4v[D4/tph]; + for (int64_t i = 0; i < D4/tph; ++i) { + pk4v[i] = pk4[tph*i + tiih]; + } + for (int64_t j = 0; j < Q; ++j) { for (int64_t i = 0; i < D4/tph; ++i) { - pk4v[i] = pk4[tph*i + tiih]; + s[j] += dot(pq4[j*T4 + hiisg*D4 + tph*i + tiih], pk4v[i]); } + } - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = 0; i < D4/tph; ++i) { - s[j] += dot(pq4[j*T4 + hiisg*D4 + tph*i + tiih], pk4v[i]); - } - } - - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 32*p + hiisg*tph + tiih] = s[j]; - } + for (int64_t j = 0; j < Q; ++j) { + ss[j*T + hiisg*tph + tiih] = s[j]; } simdgroup_barrier(mem_flags::mem_none); @@ -2156,49 +2152,38 @@ kernel void kernel_flash_attn_ext_f16( if (tiih < Q) { const int64_t j = tiih; - for (int p = 0; p < 8; ++p) { - half4 s4 = 0.0h; + half4 s4 = 0.0h; - for (int64_t i = 0; i < tph/4; ++i) { - s4 += ss4[j*T4 + 8*p + hiisg*tph/4 + i]; - } - - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mp[j][iic + p]; - - const half m = M; - - M = max(M, s); - - const half ms = m == -INFINITY ? 0.0h : exp(m - M); - const half vs = s == -INFINITY ? 0.0h : exp(s - M); - - S = S*ms + vs; - - ss[j*T + 32*p + 2*hiisg + 0] = ms; - ss[j*T + 32*p + 2*hiisg + 1] = vs; + for (int64_t i = 0; i < tph/4; ++i) { + s4 += ss4[j*T4 + hiisg*tph/4 + i]; } + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mp[j][ic]; + + const half m = M; + + M = max(M, s); + + const half ms = m == -INFINITY ? 0.0h : exp(m - M); + const half vs = s == -INFINITY ? 0.0h : exp(s - M); + + S = S*ms + vs; + + ss[j*T + 2*hiisg + 0] = ms; + ss[j*T + 2*hiisg + 1] = vs; } simdgroup_barrier(mem_flags::mem_none); + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + for (int64_t i = 0; i < D4/tph; ++i) { - half4 pv4v[8]; - - for (int p = 0; p < 8; ++p) { - const int64_t ic = iic + p; - - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - - pv4v[p] = pv4[tph*i + tiih]; - } - for (int64_t j = 0; j < Q; ++j) { - for (int p = 0; p < 8; ++p) { - const half ms = ss[j*T + 32*p + 2*hiisg + 0]; - const half vs = ss[j*T + 32*p + 2*hiisg + 1]; - ps4[j][i] = ps4[j][i]*ms + pv4v[p]*vs; - } + const half ms = ss[j*T + 2*hiisg + 0]; + const half vs = ss[j*T + 2*hiisg + 1]; + + ps4[j][i] = ps4[j][i]*ms + pv4[tph*i + tiih]*vs; } } } @@ -2260,11 +2245,11 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 2*hiisg + 0]; - const half S1 = ss[j*T + sg*(256) + 2*hiisg + 0]; + const half S0 = ss[j*T + 2*hiisg + 0]; + const half S1 = ss[j*T + sg*(32) + 2*hiisg + 0]; - const half M0 = ss[j*T + 2*hiisg + 1]; - const half M1 = ss[j*T + sg*(256) + 2*hiisg + 1]; + const half M0 = ss[j*T + 2*hiisg + 1]; + const half M1 = ss[j*T + sg*(32) + 2*hiisg + 1]; M = max(M0, M1); @@ -2279,7 +2264,7 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t i = 0; i < D4/tph; ++i) { - //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(256)/4 + hiisg*D4 + tph*i + tiih]*ms1; + //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(32)/4 + hiisg*D4 + tph*i + tiih]*ms1; ps4[j][i] = ps4[j][i]*ms0 + pq4[j*T4 + hiisg*D4 + tph*i + tiih]*ms1; } }