diff --git a/ggml-metal.m b/ggml-metal.m index 74a6bff40..3f098d396 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -718,7 +718,8 @@ void ggml_metal_graph_compute( // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 GGML_ASSERT(ne00 == ne10); - GGML_ASSERT(ne02 == ne12); + // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere + GGML_ASSERT(ne03 == ne13); if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && @@ -746,11 +747,11 @@ void ggml_metal_graph_compute( initWithDevice:ctx->device transposeLeft:false transposeRight:true resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; - // we need to do ne02 multiplications + // we need to do ne12 multiplications // TODO: is there a way to do this in parallel - currently very slow .. // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS - for (int64_t i02 = 0; i02 < ne02; ++i02) { - size_t offs_src0_cur = offs_src0 + i02*nb02; + for (int64_t i02 = 0; i02 < ne12; ++i02) { + size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now size_t offs_src1_cur = offs_src1 + i02*nb12; size_t offs_dst_cur = offs_dst + i02*nb2; @@ -772,8 +773,6 @@ void ggml_metal_graph_compute( switch (src0t) { case GGML_TYPE_F16: { - GGML_ASSERT(ne02 == ne12); - nth0 = 64; nth1 = 1; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; @@ -853,16 +852,18 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { diff --git a/ggml-metal.metal b/ggml-metal.metal index 696b33ce7..8d26b5ec2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32( device float * dst, constant int64_t & ne00, constant int64_t & ne01, + constant int64_t & ne02, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant int64_t & ne10, constant int64_t & ne11, + constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); sum[tpitg.x] = 0.0f; @@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32( } } + kernel void kernel_alibi_f32( device const float * src0, device float * dst,