From af3eda9c77ab63a2fa7dd51586f439924d7fa0c1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Jan 2024 11:18:24 +0200 Subject: [PATCH] wip --- ggml-metal.m | 9 +++-- ggml-metal.metal | 101 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 81 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index b41f29681..00a8a0e92 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,11 +2253,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !! - const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nhptg = 4; // heads per threadgroup !! sync with kernel template arguments !! + const int64_t nqptg = 2; // queries per threadgroup !! sync with kernel template arguments !! - const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); + const size_t smem = nqptg*(nhptg*ne00 + nsg*(256))*(sizeof(float)/2); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 2c22b143c..1a6eaed14 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2078,17 +2078,18 @@ kernel void kernel_flash_attn_ext_f16( const int64_t D4 = D/4; - const int64_t T = (H*D + nsg*(H*D + 256)); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = (H*D + nsg*(256)); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); - threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(H*D + 256) + 1*H*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(H*D + 256) + 2*H*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(H*D + 256) + 2*H*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(256) + 1*H*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(256) + 1*H*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup + half4 ps4[Q][D4/tph]; + // load H heads from Q to shared memory for (int64_t i = 0; i < D4/tph; ++i) { for (int64_t j = sgitg; j < Q; j += nsg) { @@ -2100,7 +2101,8 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t j = 0; j < Q; ++j) { - ps4[j*T4 + hiisg*D4 + tph*i + tiih] = 0.0h; + //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = 0.0h; + ps4[j][i] = 0.0h; } } @@ -2191,16 +2193,12 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t j = 0; j < Q; ++j) { - half4 ps4v = ps4[j*T4 + hiisg*D4 + tph*i + tiih]; - for (int p = 0; p < 8; ++p) { const half ms = ss[j*T + 32*p + 2*hiisg + 0]; const half vs = ss[j*T + 32*p + 2*hiisg + 1]; - ps4v = ps4v*ms + pv4v[p]*vs; + ps4[j][i] = ps4[j][i]*ms + pv4v[p]*vs; } - - ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4v; } } } @@ -2215,15 +2213,58 @@ kernel void kernel_flash_attn_ext_f16( threadgroup_barrier(mem_flags::mem_threadgroup); // reduce the warps - if (sgitg == 0) { - for (int64_t j = 0; j < Q; ++j) { - for (int64_t sg = 1; sg < nsg; ++sg) { + //if (sgitg == 0) { + // for (int64_t j = 0; j < Q; ++j) { + // for (int64_t sg = 1; sg < nsg; ++sg) { - const half S0 = ss[j*T + 2*hiisg + 0]; - const half S1 = ss[j*T + sg*(H*D + 256) + 2*hiisg + 0]; + // const half S0 = ss[j*T + 2*hiisg + 0]; + // const half S1 = ss[j*T + sg*(256) + 2*hiisg + 0]; - const half M0 = ss[j*T + 2*hiisg + 1]; - const half M1 = ss[j*T + sg*(H*D + 256) + 2*hiisg + 1]; + // const half M0 = ss[j*T + 2*hiisg + 1]; + // const half M1 = ss[j*T + sg*(256) + 2*hiisg + 1]; + + // M = max(M0, M1); + + // const half ms0 = exp(M0 - M); + // const half ms1 = exp(M1 - M); + + // S = S0*ms0 + S1*ms1; + + // if (tiih == 0) { + // ss[j*T + 2*hiisg + 0] = S; + // ss[j*T + 2*hiisg + 1] = M; + // } + + // for (int64_t i = 0; i < D4/tph; ++i) { + // ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(256)/4 + hiisg*D4 + tph*i + tiih]*ms1; + // } + // } + + // for (int64_t i = 0; i < D4/tph; ++i) { + // ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]/S; + // } + // } + //} + + for (int64_t sg = 1; sg < nsg; ++sg) { + if (sgitg == sg) { + // store heads to shared memory - reuse pq4 + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = 0; i < D4/tph; ++i) { + pq4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j][i]; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 2*hiisg + 0]; + const half S1 = ss[j*T + sg*(256) + 2*hiisg + 0]; + + const half M0 = ss[j*T + 2*hiisg + 1]; + const half M1 = ss[j*T + sg*(256) + 2*hiisg + 1]; M = max(M0, M1); @@ -2238,12 +2279,21 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t i = 0; i < D4/tph; ++i) { - ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(H*D + 256)/4 + hiisg*D4 + tph*i + tiih]*ms1; + //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(256)/4 + hiisg*D4 + tph*i + tiih]*ms1; + ps4[j][i] = ps4[j][i]*ms0 + pq4[j*T4 + hiisg*D4 + tph*i + tiih]*ms1; } } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (sgitg == 0) { + for (int64_t j = 0; j < Q; ++j) { + S = ss[j*T + 2*hiisg + 0]; for (int64_t i = 0; i < D4/tph; ++i) { - ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]/S; + //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]/S; + ps4[j][i] = ps4[j][i]/S; } } } @@ -2255,15 +2305,16 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { for (int64_t i = 0; i < D4/tph; ++i) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + tph*i + tiih] = (float4) ps4[j*T4 + hiisg*D4 + tph*i + tiih]; + //dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + tph*i + tiih] = (float4) ps4[j*T4 + hiisg*D4 + tph*i + tiih]; + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + tph*i + tiih] = (float4) ps4[j][i]; } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2, 4>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2, 4>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4, 2>; kernel void kernel_cpy_f16_f16( device const half * src0,