From 2cbcba829f9c7e80a77473c0eadc7d14d3287681 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Dec 2023 14:18:42 +0200 Subject: [PATCH] metal : add more general support for ggml_get_rows + tests --- ggml-metal.m | 16 +++++----- ggml-metal.metal | 62 ++++++++++++++++++++++++++++++++++---- ggml.c | 6 ++-- tests/test-backend-ops.cpp | 19 ++++++------ 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 8389373a8..28c628958 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -805,8 +805,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: case GGML_OP_CONCAT: case GGML_OP_ADD: case GGML_OP_MUL: @@ -828,7 +829,6 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_MUL_MAT_ID: return true; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: { return op->ne[0] % 4 == 0; } @@ -1568,16 +1568,18 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7]; const int64_t n = ggml_nelements(src1); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_RMS_NORM: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 2f8ea22d6..6723200c7 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3223,14 +3223,16 @@ kernel void kernel_get_rows( device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, constant uint64_t & nb1, uint tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg[[threads_per_threadgroup]]) { - const int i = tgpig; - const int r = ((device int32_t *) src1)[i]; + uint tptg [[threads_per_threadgroup]]) { + const int64_t i = tgpig; + const int64_t r = ((device int32_t *) src1)[i]; - for (int ind = tiitg; ind < ne00/16; ind += tptg) { + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) { float4x4 temp; dequantize_func( ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); @@ -3238,6 +3240,52 @@ kernel void kernel_get_rows( } } +kernel void kernel_get_rows_f32( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb1, + uint tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tptg [[threads_per_threadgroup]]) { + const int64_t i = tgpig; + const int64_t r = ((device int32_t *) src1)[i]; + + const int64_t i02 = i/ne10; + + for (int ind = tiitg; ind < ne00; ind += tptg) { + ((device float *) ((device char *) dst + i*nb1))[ind] = + ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb1, + uint tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tptg [[threads_per_threadgroup]]) { + const int64_t i = tgpig; + const int64_t r = ((device int32_t *) src1)[i]; + + const int64_t i02 = i/ne10; + + for (int ind = tiitg; ind < ne00; ind += tptg) { + ((device float *) ((device char *) dst + i*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 @@ -3490,11 +3538,13 @@ typedef void (get_rows_t)( device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, constant uint64_t & nb1, uint, uint, uint); -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; diff --git a/ggml.c b/ggml.c index 4bdb70248..5f94ede00 100644 --- a/ggml.c +++ b/ggml.c @@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q( dequantize_row_q( (const void *) ((char *) src0->data + i02*nb02 + r*nb01), - (float *) ((char *) dst->data + i*dst->nb[1]), nc); + (float *) ((char *) dst->data + i*nb1), nc); } } @@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16( for (int j = 0; j < nc; ++j) { ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + ((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v); } } } @@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i02 = i/ne10; ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) dst->data + i*nb1), (float *) ((char *) src0->data + i02*nb02 + r*nb01)); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index dddc2b899..c98ca45e0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -488,17 +488,18 @@ struct test_get_rows : public test_case { const int n; // cols const int m; // rows const int r; // rows to get + const int b; // batch size std::string vars() override { return VARS_TO_STR4(type, n, m, r); } - test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3) - : type(type), n(n), m(m), r(r) {} + test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1) + : type(type), n(n), m(m), r(r), b(b) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m); - ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r); + ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b); + ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b); ggml_tensor * out = ggml_get_rows(ctx, in, rows); return out; } @@ -507,11 +508,11 @@ struct test_get_rows : public test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (t->type == GGML_TYPE_I32) { // rows - std::vector data(r); - for (int i = 0; i < r; i++) { + std::vector data(r*b); + for (int i = 0; i < r*b; i++) { data[i] = rand() % m; } - ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int)); + ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int)); } else { init_tensor_uniform(t); } @@ -1125,8 +1126,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_get_rows(type, 10, 5, 3)); - test_cases.emplace_back(new test_get_rows(type, 16, 5, 3)); + test_cases.emplace_back(new test_get_rows(type, 10, 5, 3, 7)); + test_cases.emplace_back(new test_get_rows(type, 16, 5, 3, 7)); } test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));