From 1173f49c3bbe30810af4aeb77219eba7e05f658d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 20 Jan 2024 17:32:28 +0200
Subject: [PATCH] metal : initial implementation

---
 ggml-metal.m               |  69 +++++++++++++------
 ggml-metal.metal           | 138 ++++++++++++++++++++++++++++++++++---
 ggml.c                     |   2 +-
 tests/test-backend-ops.cpp |   4 ++
 4 files changed, 180 insertions(+), 33 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index 4d85dd3dd..556c53482 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -278,6 +278,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
             NSURL * libURL = [NSURL fileURLWithPath:libPath];
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
             ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+            if (error) {
+                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
         } else {
             GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
@@ -316,13 +320,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
             //[options setFastMathEnabled:false];
 
             ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+            if (error) {
+                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
         }
     }
-
-    if (error) {
-        GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-        return NULL;
-    }
 }
 
     // print MTL GPU family:
@@ -396,6 +399,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
            kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+           GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                   (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                   (int) kernel->pipeline.threadExecutionWidth); \
            if (error) { \
                GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
                return NULL; \
@@ -2171,12 +2177,28 @@ static bool ggml_metal_graph_compute(
                         struct ggml_tensor * src2 = gf->nodes[i]->src[2];
                         struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
+                        GGML_ASSERT(ggml_are_same_shape(src1, src2));
+
                         size_t offs_src2 = 0;
                         size_t offs_src3 = 0;
 
-                        id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil;
+                        GGML_ASSERT(src2);
+                        id<MTLBuffer> id_src2 = ggml_metal_get_buffer(ctx, src2, &offs_src2);
+
                         id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil;
 
+                        const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                        const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                        const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                        const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                        const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                        const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                        const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                        const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
                         float scale;
                         memcpy(&scale, dst->op_params, sizeof(float));
 
@@ -2197,25 +2219,28 @@ static bool ggml_metal_graph_compute(
                         [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
                         [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
                         [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-                        [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:14];
-                        [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:15];
-                        [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:16];
-                        [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:18];
-                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:19];
-                        [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:20];
-                        [encoder setBytes:&scale length:sizeof( float) atIndex:21];
+                        [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+                        [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+                        [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+                        [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+                        [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
+                        [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
+                        [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:23];
+                        [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:24];
+                        [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:25];
+                        [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:26];
+                        [encoder setBytes:&scale length:sizeof( float) atIndex:27];
 
-                        const int nwarps = 4;
+                        const int nwarps = 1;
 
-                        // each warp needs n_embd_head elements
-                        GGML_ASSERT(nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength);
-                        [encoder setThreadgroupMemoryLength:nwarps*ne00*sizeof(float) atIndex:0];
+                        GGML_ASSERT(2*32*nwarps*ne00*sizeof(float) <= ctx->device.maxThreadgroupMemoryLength);
+                        [encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*sizeof(float) atIndex:0];
 
-                        const int nth = MIN(1024, ne0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nwarps, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                     } break;
                 case GGML_OP_DUP:
                 case GGML_OP_CPY:
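Note on the dispatch above: the new launch geometry assigns one SIMD lane per query row, so the grid covers (ne01 + 31)/32 threadgroups of 32 threads along the row dimension, and the threadgroup memory must hold two D = ne00 element scratch vectors per lane (the running V accumulator and the cached Q row), allocated in float-sized slots. A minimal C sketch of the same sizing math, with illustrative names only (flash_attn_launch_dims is not a ggml function):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Sketch of the launch-geometry math encoded by the encoder calls above:
    // one threadgroup of 32 threads covers 32 query rows, one row per SIMD
    // lane, and each lane needs two ne00-element scratch vectors in
    // threadgroup memory, allocated in float-sized slots.
    static void flash_attn_launch_dims(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                                       size_t max_tg_mem, int64_t grid[3], size_t * tg_mem) {
        const int nwarps = 1; // one simdgroup per threadgroup in this version

        grid[0] = (ne01 + 31)/32; // one SIMD lane per query row
        grid[1] =  ne02;          // higher dims of q, one threadgroup each
        grid[2] =  ne03;

        // two scratch vectors of ne00 floats per lane, 32 lanes per simdgroup
        *tg_mem = 2*32*nwarps*ne00*sizeof(float);
        assert(*tg_mem <= max_tg_mem); // mirrors the GGML_ASSERT above
    }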
diff --git a/ggml-metal.metal b/ggml-metal.metal
index a1e1755a3..5986bcb42 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1960,10 +1960,10 @@ kernel void kernel_leaky_relu_f32(
 }
 
 kernel void kernel_flash_attn_ext_f16(
-        device const  half * q,
-        device const  half * k,
-        device const  half * v,
-        device const float * mask,
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
@@ -1973,20 +1973,138 @@ kernel void kernel_flash_attn_ext_f16(
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
         constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne2,
         constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
         constant     float & scale,
         threadgroup  float * shared [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]]) {
-    // TODO: implement
+        uint3    ntg[[threads_per_threadgroup]],
+        uint   tiisg[[thread_index_in_simdgroup]],
+        uint   sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int64_t iq3 = tgpig[2];
+    const int64_t iq2 = tgpig[1];
+    const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg;
+
+    if (iq1 >= ne01) {
+        return;
+    }
+
+    const int64_t D = ne00;
+
+    // TODO: can we move this to the stack?
+    threadgroup half * V16 = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D);
+
+    // initialize with zeros
+    for (int64_t d = 0; d < D; ++d) {
+        V16[d] = 0.0h;
+    }
+
+    threadgroup half * pq = (threadgroup half *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D);
+
+    half S = 0.0h;
+    half M = -INFINITY;
+
+    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+
+    device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
+
+    // assume K and V are same shape
+    const int64_t ne22 = ne12;
+    const int64_t ne23 = ne13;
+
+    const uint64_t nb21 = nb11;
+    const uint64_t nb22 = nb12;
+    const uint64_t nb23 = nb13;
+
+    // broadcast
+    const int64_t rk2 = ne02/ne12;
+    const int64_t rk3 = ne03/ne13;
+
+    const int64_t rv2 = ne02/ne22;
+    const int64_t rv3 = ne03/ne23;
+
+    // k indices
+    const int64_t ik2 = iq2 / rk2;
+    const int64_t ik3 = iq3 / rk3;
+
+    // v indices
+    const int64_t iv2 = iq2 / rv2;
+    const int64_t iv3 = iq3 / rv3;
+
+    // load Q to shared memory
+    for (int64_t d = 0; d < D; ++d) {
+        pq[d] = ((device const half *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
+    }
+
+    for (int64_t ic = 0; ic < ne11; ++ic) {
+        const half mv = mp ? mp[ic] : 0.0h;
+        if (mv == -INFINITY) {
+            continue;
+        }
+
+        half s = 0.0f;
+
+        //device const half * pq = (device const half *) ((device char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+        device const half * pk = (device const half *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13));
+
+        for (int64_t d = 0; d < D; ++d) {
+            s += pk[d] * pq[d];
+        }
+
+        s = s*scale + mv;
+
+        const half Mold = M;
+
+        half ms = 1.0f;
+        half vs = 1.0f;
+
+        if (s > M) {
+            M = s;
+            ms = exp(Mold - M);
+
+            // V = V*exp(Mold - M)
+            for (int64_t d = 0; d < D; ++d) {
+                V16[d] *= ms;
+            }
+        } else {
+            vs = exp(s - M);
+        }
+
+        device const half * pv = (device const half *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
+
+        // V += v*exp(s - M)
+        for (int64_t d = 0; d < D; ++d) {
+            V16[d] += pv[d] * vs;
+        }
+
+        S = S*ms + vs;
+    }
+
+    for (int64_t d = 0; d < D; ++d) {
+        V16[d] /= S;
+    }
+
+    // dst indices
+    const int64_t i1 = iq1;
+    const int64_t i2 = iq2;
+    const int64_t i3 = iq3;
+
+    for (int64_t d = 0; d < D; ++d) {
+        dst[(i3*ne2*ne1 + i2 + i1*ne1)*D + d] = V16[d];
+    }
 }
 
 kernel void kernel_cpy_f16_f16(
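For reference, the loop above is the online-softmax formulation of attention: each lane keeps a running score maximum M and denominator S, and rescales its half-precision accumulator V16 whenever a new maximum appears, so only one pass over the KV rows is needed. A scalar C sketch of the same recurrence for a single query row (illustrative only: float instead of half, contiguous K/V, no broadcast or striding):

    #include <math.h>
    #include <stdint.h>
    #include <stdlib.h>

    // Scalar sketch of the online-softmax attention loop in the kernel above,
    // for one query row q[0..D) against n_kv contiguous key/value rows.
    // All names are illustrative; the real kernel works in half precision.
    static void attn_row_ref(const float * q, const float * k, const float * v,
                             const float * mask, float * out,
                             int64_t D, int64_t n_kv, float scale) {
        float S = 0.0f;      // running softmax denominator
        float M = -INFINITY; // running maximum of the scores
        float * V = calloc(D, sizeof(float)); // running weighted sum of V rows

        for (int64_t ic = 0; ic < n_kv; ++ic) {
            const float mv = mask ? mask[ic] : 0.0f;
            if (mv == -INFINITY) {
                continue; // fully masked position, contributes nothing
            }

            // s = scale*dot(q, k_ic) + mask_ic
            float s = 0.0f;
            for (int64_t d = 0; d < D; ++d) {
                s += k[ic*D + d]*q[d];
            }
            s = s*scale + mv;

            const float Mold = M;
            float ms = 1.0f; // rescale factor for the old accumulator
            float vs = 1.0f; // weight of the new value row

            if (s > M) {
                M  = s;
                ms = expf(Mold - M); // shrink old contributions
                for (int64_t d = 0; d < D; ++d) {
                    V[d] *= ms;      // V = V*exp(Mold - M)
                }
            } else {
                vs = expf(s - M);
            }

            for (int64_t d = 0; d < D; ++d) {
                V[d] += v[ic*D + d]*vs; // V += v*exp(s - M)
            }

            S = S*ms + vs;
        }

        for (int64_t d = 0; d < D; ++d) {
            out[d] = V[d]/S; // final normalization
        }

        free(V);
    }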
diff --git a/ggml.c b/ggml.c
index e64a328fa..10df03c9c 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13419,8 +13419,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ik2 = iq2 / rk2;
 
     // v indices
-    const int iv2 = iq2 / rv2;
     const int iv3 = iq3 / rv3;
+    const int iv2 = iq2 / rv2;
 
     // online softmax / attention
     // loop over n_kv and n_head_kv
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a56c0d6c5..51a33c662 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1396,6 +1396,10 @@ struct test_flash_attn_ext : public test_case {
         return VARS_TO_STR5(typeq, hs, nh, kv, nb);
     }
 
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
     test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
         : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
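A note on the test change: the Metal kernel above keeps S, M and the V16 accumulator in half precision, while the reference path in ggml.c works in float, so a somewhat looser tolerance is to be expected; the max_nmse_err override of 5e-4 presumably exists to absorb exactly that. Per query row, the op computes the usual masked, scaled attention

    out = sum_i exp(s_i) * v_i / sum_j exp(s_j),   where s_i = scale*dot(q, k_i) + mask_i

and the streaming rescaling in the kernel (ms = exp(Mold - M), vs = exp(s - M)) reproduces this two-pass softmax result in a single pass, up to floating-point error.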