From ac3f7d8e23f1b4785de6e9d2c40d499d2ca94518 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 9 Dec 2023 19:19:03 +0100 Subject: [PATCH] ggml : get_rows : support non-contiguos tensors with gaps, generalize up to 3D --- ggml.c | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/ggml.c b/ggml.c index 07d23f427..fb6ba1fc3 100644 --- a/ggml.c +++ b/ggml.c @@ -4734,7 +4734,8 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); - GGML_ASSERT(ggml_is_matrix(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_I32); bool is_node = false; @@ -4744,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows( // TODO: implement non F32 return //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); result->op = GGML_OP_GET_ROWS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -10414,7 +10415,6 @@ static void ggml_compute_forward_get_rows_f32( GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); assert(ne0 == nc); assert(ne02 == ne11); @@ -10422,14 +10422,17 @@ static void ggml_compute_forward_get_rows_f32( assert(ggml_nrows(dst) == nr); // TODO: multi-thread - for (int64_t i = 0; i < nr; ++i) { - const int64_t r = ((int32_t *) src1->data)[i]; + // TODO: same impl for get_rows_q and get_rows_f16 + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - const int64_t i02 = i/ne10; - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*nb1), - (float *) ((char *) src0->data + i02*nb02 + r*nb01)); + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } + } } }