From 4c9388fb96ac2415fbb1239b7ba8346616606e2e Mon Sep 17 00:00:00 2001 From: Jun Hee Yoo Date: Wed, 23 Oct 2024 19:33:45 +0900 Subject: [PATCH] metal : add POOL2D and fix IM2COL (#9943) * add pool_2d Signed-off-by: Junhee Yoo * fix im2col and add unittest for N>=1024 Signed-off-by: Junhee Yoo * add tests for N % 1024 != 0 Signed-off-by: Junhee Yoo * remove trailing whitespaces Signed-off-by: Junhee Yoo * apply suggestions Signed-off-by: Junhee Yoo * apply more optimization - original IM2COL kernel + _ext with MIN() Signed-off-by: Junhee Yoo * apply review: change kernel name of pool_2d Signed-off-by: Junhee Yoo * apply review Signed-off-by: Junhee Yoo * fix more formatting and enhance readability Signed-off-by: Junhee Yoo --------- Signed-off-by: Junhee Yoo --- ggml/src/ggml-metal.m | 128 ++++++++++++++++++++++---- ggml/src/ggml-metal.metal | 178 +++++++++++++++++++++++++++++++++++++ tests/test-backend-ops.cpp | 10 +++ 3 files changed, 298 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 172a0f925..e9541441c 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -241,6 +241,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32, @@ -272,6 +274,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); @@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } [metal_library release]; @@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: return false; + case GGML_OP_POOL_2D: case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: @@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node( } break; case GGML_OP_IM2COL: { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); @@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node( const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - id pipeline = nil; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; + + const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + case GGML_TYPE_F32: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); + } break; + case GGML_TYPE_F16: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); + } break; default: GGML_ABORT("fatal error"); }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + if (is_gt_mttpt) { + [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; + [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; + [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; + + const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); + + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } } break; case GGML_OP_UPSCALE: { @@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); + + const int32_t * opts = dst->op_params; + enum ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case GGML_TYPE_F32: { + switch(op) { + case GGML_OP_POOL_AVG: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; + case GGML_OP_POOL_MAX: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + } + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; + [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; + [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; + [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; + [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b2000323..71b58be1f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1933,6 +1933,85 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + + const int32_t d = tgpig[0] / CHW; + const int32_t chw = tgpig[0] % CHW; + const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int32_t HW = tgpig[0] % KHW; + + const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= N) { + return; + } + + const int32_t tpitg_1 = HW / KW; + const int32_t tpitg_2 = HW % KW; + + const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + + const int32_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + kernel void kernel_upscale_f32( device const char * src0, device char * dst, @@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} + +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ee1a8877e..e087f7ba5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3316,6 +3316,16 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); + // test cases for 2D im2col + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)