metal : efficient flash_attn_f16 implementation

Georgi Gerganov 2024-01-23 18:27:54 +02:00
parent 17720fad66
commit 1446a12b29
3 changed files with 195 additions and 118 deletions

ggml-metal.m

@@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute(
         struct ggml_tensor * src3 = gf->nodes[i]->src[3];

         GGML_ASSERT(ggml_are_same_shape(src1, src2));
+        GGML_ASSERT(src3);

         size_t offs_src2 = 0;
         size_t offs_src3 = 0;
@@ -2252,15 +2253,20 @@ static bool ggml_metal_graph_compute(
         [encoder setBytes:&ne3   length:sizeof( int64_t) atIndex:26];
         [encoder setBytes:&scale length:sizeof(   float) atIndex:27];

-        const int64_t nwarps = 32;
-        const int64_t nhptg  = 2; // heads per threadgroup
+        // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+        const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps)

-        const size_t smem = (nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
+        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+        const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
+
+        //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
+        const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
+        //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+        GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);

         [encoder setThreadgroupMemoryLength:smem atIndex:0];

-        [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + nhptg - 1)/(nhptg), ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
+        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     } break;
 case GGML_OP_DUP:
 case GGML_OP_CPY:
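
Aside on the new launch configuration: each threadgroup now covers nqptg = 8 query rows (grid x is ne01/8 rounded up) instead of one query row across nhptg heads, and small batches get 4 simdgroups instead of 2. A minimal host-side sketch (not part of the commit; fa_smem_bytes is a hypothetical helper) that reproduces the shared-memory budget above and checks it against a typical 32 KB Metal threadgroup-memory limit:

    // Standalone sketch: recompute the threadgroup memory size chosen above.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // hypothetical helper; in ggml-metal.m this is computed inline
    static size_t fa_smem_bytes(int64_t ne00 /* head size D */, int64_t ne01 /* n queries */) {
        const int64_t nsg   = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup
        const int64_t nqptg = 8;                // queries per threadgroup
        const int64_t ncpsg = 32;               // cache values per simdgroup
        // per query: one Q copy (ne00 halfs) + per-simdgroup scratch (ne00 + ncpsg halfs)
        return nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
    }

    int main() {
        printf("D=128, batch=32: %zu bytes\n", fa_smem_bytes(128, 32)); // 7168
        printf("D=128, batch=1 : %zu bytes\n", fa_smem_bytes(128, 1));  // 12288
        assert(fa_smem_bytes(128, 32) <= 32768); // typical maxThreadgroupMemoryLength
    }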

ggml-metal.metal

@@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]);

-template<int64_t D, int64_t R> // head size, rows per threadgroup
+template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache values per simdgroup
 kernel void kernel_flash_attn_ext_f16(
         device const char * q,
         device const char * k,
@@ -2031,178 +2031,247 @@ kernel void kernel_flash_attn_ext_f16(
         uint3 ntg[[threads_per_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    const uint nsg = ntg.y;         // number of simdgroups
-    const uint tph = N_SIMDWIDTH/R; // threads per head
+    const uint nsg = ntg.y; // number of simdgroups

     const int64_t iq3 = tgpig[2];
-    const int64_t iq2 = tgpig[1]*R + tiisg/tph;
-    const int64_t iq1 = tgpig[0];
+    const int64_t iq2 = tgpig[1];
+    const int64_t iq1 = tgpig[0]*Q;
-
-    if (iq2 >= ne02) {
-        return;
-    }
-
-    // assume K and V are same shape
-    const int64_t ne22 = ne12;
-    const int64_t ne23 = ne13;
-
-    const uint64_t nb21 = nb11;
-    const uint64_t nb22 = nb12;
-    const uint64_t nb23 = nb13;
-
-    // broadcast
-    const int64_t rk2 = ne02/ne12;
-    const int64_t rk3 = ne03/ne13;
-
-    const int64_t rv2 = ne02/ne22;
-    const int64_t rv3 = ne03/ne23;
-
-    // k indices
-    const int64_t ik2 = iq2 / rk2;
-    const int64_t ik3 = iq3 / rk3;
-
-    // v indices
-    const int64_t iv2 = iq2 / rv2;
-    const int64_t iv3 = iq3 / rv3;
-
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
-    device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;

     const int64_t D4 = D/4;
+    const int64_t N4 = N_SIMDWIDTH;
+    const int64_t L4 = (D4 + N4 - 1)/N4;
+    const int64_t D8 = D/8;

-    threadgroup half4 * pq4 = (threadgroup half4 *) (shared +                  0*R*D);
-    threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(R*D + 32) + 1*R*D);
-    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*(R*D + 32) + 2*R*D);
+    const int64_t T  = D + nsg*(D + 1*C); // shared memory size per query in half
+    const int64_t T4 = T/4;               // shared memory size per query in half4

-    const uint tiih  = tiisg%tph; // thread index in head
-    const uint hiisg = tiisg/tph; // head index in simdgroup
+    threadgroup half  * pq  = (threadgroup half  *) (shared +                0*D);
+    threadgroup half4 * pq4 = (threadgroup half4 *) (shared +                0*D);
+    threadgroup half  * ps  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 1*D);
+    threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);
+    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*(D + 1*C) + 2*D);
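
For orientation, the layout behind these pointers: each of the Q query rows owns T = D + nsg*(D + C) halfs of threadgroup memory — one shared copy of the query (pq), then per simdgroup a partial output accumulator (ps) and a scratch row holding the (S, M) softmax state plus the attention scores (ss). A standalone sketch with assumed values D = 128, C = 32, nsg = 2 (offsets in half units, mirroring the pointer arithmetic above):

    #include <cstdio>

    int main() {
        const int D = 128, C = 32, nsg = 2;
        const int T = D + nsg*(D + 1*C); // halfs per query row
        for (int sgitg = 0; sgitg < nsg; ++sgitg) {
            printf("sg %d: pq=%d ps=%d ss=%d (row stride T=%d)\n",
                   sgitg,
                   0,                      // shared Q copy, one per threadgroup
                   sgitg*(D + 1*C) + 1*D,  // this simdgroup's partial output
                   sgitg*(D + 1*C) + 2*D,  // this simdgroup's S, M and scores
                   T);
        }
    }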
-    // load R heads from Q to shared memory
-    for (int64_t i = 0; i < D4/tph; ++i) {
-        if (sgitg == 0) {
-            pq4[hiisg*D4 + tph*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
-        }
-
-        ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
-    }
+    for (int64_t i = 0; i < L4; ++i) {
+        // load heads from Q to shared memory
+        for (int64_t j = sgitg; j < Q; j += nsg) {
+            if (iq1 + j < ne01) {
+                pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
+            } else {
+                pq4[j*T4 + N4*i + tiisg] = 0.0h;
+            }
+        }
+
+        // zero out shared memory
+        for (int64_t j = 0; j < Q; ++j) {
+            ps4[j*T4 + N4*i + tiisg] = 0.0h;
+        }
+    }
+
+    if (tiisg < C) {
+        for (int64_t j = 0; j < Q; ++j) {
+            ss[j*T + 0 + tiisg] = 0.0h;
+        }
+    }

     threadgroup_barrier(mem_flags::mem_threadgroup);

-    half S = 0.0h;
-    half M = -INFINITY;
-
-    for (int64_t ic = sgitg; ic < ne11; ic += nsg) {
-        const half mv = mp ? mp[ic] : 0.0h;
-        if (mv == -INFINITY) {
-            continue;
-        }
+    {
+        half S[Q] = { 0.0h };
+        half M[Q] = { -INFINITY };
+
+        // assume K and V are same shape
+        const int64_t ne22 = ne12;
+        const int64_t ne23 = ne13;
+
+        const uint64_t nb21 = nb11;
+        const uint64_t nb22 = nb12;
+        const uint64_t nb23 = nb13;
+
+        // broadcast
+        const int64_t rk2 = ne02/ne12;
+        const int64_t rk3 = ne03/ne13;
+
+        const int64_t rv2 = ne02/ne22;
+        const int64_t rv3 = ne03/ne23;
+
+        // k indices
+        const int64_t ik2 = iq2 / rk2;
+        const int64_t ik3 = iq3 / rk3;
+
+        // v indices
+        const int64_t iv2 = iq2 / rv2;
+        const int64_t iv3 = iq3 / rv3;
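
These ratios implement head broadcast for grouped-query attention: rk2 query heads share one K head (and rv2 share one V head), so query head iq2 reads KV head iq2/rk2. A toy illustration with assumed dimensions (32 Q heads, 8 K/V heads):

    #include <cstdio>

    int main() {
        const int ne02 = 32, ne12 = 8; // Q heads, K/V heads (hypothetical dims)
        const int rk2  = ne02/ne12;    // = 4 query heads per K/V head
        for (int iq2 = 0; iq2 < ne02; ++iq2) {
            printf("q head %2d -> kv head %d\n", iq2, iq2/rk2);
        }
    }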
+        simdgroup_half8x8 mq[D8];
+
+        for (int64_t i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], pq + i*8, T);
+        }
+
+        // TODO: this can be improved
+        device const float * mp[Q];
+
+        {
+            const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+
+            for (int64_t j = 0; j < Q; ++j) {
+                if (iq1 + j < ne01) {
+                    mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31);
+                } else {
+                    mp[j] = nullptr;
+                }
+            }
+        }

-        device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
-        device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
-
-        half4 s4 = 0.0h;
-
-#pragma unroll
-        for (int64_t i = 0; i < D4/tph; ++i) {
-            s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
-        }
-
-        ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
-
-        simdgroup_barrier(mem_flags::mem_threadgroup);
-
-        if (tiih == 0) {
-            half s = 0.0h;
-
-#pragma unroll
-            for (int64_t i = 0; i < tph; ++i) {
-                s += ss[hiisg*tph + i];
-            }
-
-            s = s*scale + mv;
-
-            const half m = M;
-
-            M = max(M, s);
-
-            const half ms = exp(m - M);
-            const half vs = exp(s - M);
-
-            S = S*ms + vs;
-
-            ss[2*hiisg + 0] = ms;
-            ss[2*hiisg + 1] = vs;
-        }
+        for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
+            // skip -INF blocks
+            // TODO: double-check this
+            {
+                float smc = -INFINITY;
+
+                for (int64_t j = 0; j < Q; ++j) {
+                    const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY;
+                    smc = simd_max(max(smc, mc));
+                }
+
+                if (smc == -INFINITY) {
+                    continue;
+                }
+            }
+
+            // Q*K^T
+            {
+                simdgroup_half8x8 mk;
+
+                for (int cc = 0; cc < C/8; ++cc) {
+                    simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (int64_t i = 0; i < D8; ++i) {
+                        simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    simdgroup_store(mqk, ss + 8*cc, T, 0, false);
+                }
+            }
+
+            // online softmax
+            for (int64_t j = 0; j < Q; ++j) {
+                const int64_t p = tiisg;
+
+                const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
+
+                half m = M[j];
+
+                M[j] = simd_max(max(M[j], s));
+
+                const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
+                const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+
+                S[j] = S[j]*ms + simd_sum(vs);
+
+                for (int64_t i = 0; i < L4; ++i) {
+                    ps4[j*T4 + N4*i + tiisg] *= ms;
+                }
+
+                ss[j*T + p] = vs;
+            }
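
This is the online-softmax recurrence of flash attention: each new block of scores may raise the running maximum M[j], in which case the accumulated sum S[j] and the partial output in ps4 are rescaled by exp(m_old - m_new) before the new exp-weights are folded in. A scalar C++ analogue (one score at a time, whereas the kernel handles 32 lanes at once and aggregates with simd_sum; the weighted V contribution is added in the next stage):

    #include <cmath>
    #include <vector>

    void online_softmax_step(float s, float &M, float &S, std::vector<float> &O) {
        const float m = M;
        M = std::max(M, s);
        const float ms = (m == -INFINITY) ? 0.0f : std::exp(m - M); // rescale old state
        const float vs = (s == -INFINITY) ? 0.0f : std::exp(s - M); // weight of new score
        S = S*ms + vs;
        for (float &o : O) o *= ms; // the kernel rescales ps4 here; vs*V comes later
    }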
+            // (Q*K^T)*V
+            {
+                simdgroup_half8x8 mv;
+
+                for (int64_t i = 0; i < D8; ++i) {
+                    simdgroup_half8x8 mp[C/8];
+                    simdgroup_half8x8 mqkv;
+
+                    simdgroup_load(mqkv, ps + i*8, T, 0, false);
+
+                    for (int cc = 0; cc < C/8; ++cc) {
+                        simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
+                    }
+
+                    for (int cc = 0; cc < C/8; ++cc) {
+                        device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                        simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
+
+                        simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
+                    }
+
+                    simdgroup_store(mqkv, ps + i*8, T, 0, false);
+                }
+            }
+        }
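
The vs weights stored in ss are consumed here: per iteration each simdgroup adds vs * V for its C cache entries onto the already-rescaled partial output, tiled as D/8 x C/8 simdgroup matrix products. The scalar equivalent of one iteration (a sketch, not from the commit):

    #include <cstddef>
    #include <vector>

    void accumulate_v(const std::vector<float> &vs,               // C softmax weights
                      const std::vector<std::vector<float>> &V,  // C x D value rows
                      std::vector<float> &O) {                   // D accumulator
        for (size_t c = 0; c < V.size(); ++c) {
            for (size_t d = 0; d < O.size(); ++d) {
                O[d] += vs[c]*V[c][d];
            }
        }
    }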
-        simdgroup_barrier(mem_flags::mem_threadgroup);
-
-        const half ms = ss[2*hiisg + 0];
-        const half vs = ss[2*hiisg + 1];
-
-#pragma unroll
-        for (int64_t i = 0; i < D4/tph; ++i) {
-            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
-        }
-    }
-
-    if (tiih == 0) {
-        ss[2*hiisg + 0] = S;
-        ss[2*hiisg + 1] = M;
-    }
+        for (int64_t j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
+            }
+        }
+    }

     threadgroup_barrier(mem_flags::mem_threadgroup);

     // reduce the warps
+    // TODO: try parallel reduce
     if (sgitg == 0) {
+        half S = { 0.0h };
+        half M = { -INFINITY };
+
         for (int64_t sg = 1; sg < nsg; ++sg) {
-            const half S0 = ss[                 2*hiisg + 0];
-            const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
-
-            const half M0 = ss[                 2*hiisg + 1];
-            const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
-
-            M = max(M0, M1);
-
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
-
-            S = S0*ms0 + S1*ms1;
-
-            if (tiih == 0) {
-                ss[2*hiisg + 0] = S;
-                ss[2*hiisg + 1] = M;
-            }
-
-            for (int64_t i = 0; i < D4/tph; ++i) {
-                ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
-            }
+            for (int64_t j = 0; j < Q; ++j) {
+                const half S0 = ss[j*T +                0];
+                const half S1 = ss[j*T + sg*(D + 1*C) + 0];
+
+                const half M0 = ss[j*T +                1];
+                const half M1 = ss[j*T + sg*(D + 1*C) + 1];
+
+                M = max(M0, M1);
+
+                const half ms0 = exp(M0 - M);
+                const half ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
+                }
+
+                for (int64_t i = 0; i < L4; ++i) {
+                    ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1;
+                }
+            }
         }
-
-        for (int64_t i = 0; i < D4/tph; ++i) {
-            ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
-        }
     }
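
The merge rule above is exact, not approximate: partial softmax sums over disjoint key ranges combine via S = S0*exp(M0 - M) + S1*exp(M1 - M) with M = max(M0, M1), and the partial outputs combine with the same two factors. A quick numerical check (not part of the commit):

    #include <cassert>
    #include <cmath>

    int main() {
        const float x[4] = {1.0f, 3.0f, 2.0f, 5.0f}; // arbitrary scores
        // partial softmax sums over {x0,x1} and {x2,x3}
        float M0 = std::max(x[0], x[1]), S0 = std::exp(x[0]-M0) + std::exp(x[1]-M0);
        float M1 = std::max(x[2], x[3]), S1 = std::exp(x[2]-M1) + std::exp(x[3]-M1);
        float M = std::max(M0, M1);
        float S = S0*std::exp(M0 - M) + S1*std::exp(M1 - M);
        // reference: direct sum over all four scores
        float R = 0.0f;
        for (float v : x) R += std::exp(v - M);
        assert(std::fabs(S - R) < 1e-6f);
    }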
+    simdgroup_barrier(mem_flags::mem_threadgroup);

-    // dst indices
-    const int64_t i1 = iq1;
-    const int64_t i2 = iq2;
-    const int64_t i3 = iq3;
-
     device float4 * dst4 = (device float4 *) dst;

     if (sgitg == 0) {
-        for (int64_t i = 0; i < D4/tph; ++i) {
-            dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = (float4) ps4[hiisg*D4 + tph*i + tiih];
+        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const half S = ss[j*T + 0];
+
+            for (int64_t i = 0; i < L4; ++i) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S;
+            }
         }
     }
 }

-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2>;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;

 kernel void kernel_cpy_f16_f16(
         device const half * src0,

tests/test-backend-ops.cpp

@@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case {
     }

     double max_nmse_err() override {
-        return 5e-4;
+        return 5e-5;
     }

     test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
@@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32,  96, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));
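
The new cases keep the KV length at 256 (a multiple of the kernel's cache block C = 32) and vary the last argument across 8, 7 and 1. Assuming the trailing constructor parameters are head size, number of heads, KV length and batch size, the 7 and 1 batches exercise the kernel's iq1 + j < ne01 guards for batch sizes that are not a multiple of its Q = 8 queries per threadgroup.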

 #if !defined(__SANITIZE_THREAD__)
 // FIXME: these tests use too much memory with thread sanitizer