ggml : get_rows : support non-contiguos tensors with gaps, generalize up to 3D

This commit is contained in:
slaren 2023-12-09 19:19:03 +01:00
parent 8c5b66eeaa
commit ac3f7d8e23

21
ggml.c
View File

@ -4734,7 +4734,8 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b) { struct ggml_tensor * b) {
GGML_ASSERT(a->ne[2] == b->ne[1]); GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(ggml_is_matrix(b) && b->type == GGML_TYPE_I32); GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_I32);
bool is_node = false; bool is_node = false;
@ -4744,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows(
// TODO: implement non F32 return // TODO: implement non F32 return
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1]); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_OP_GET_ROWS; result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -10414,7 +10415,6 @@ static void ggml_compute_forward_get_rows_f32(
GGML_TENSOR_BINARY_OP_LOCALS GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00; const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
assert(ne0 == nc); assert(ne0 == nc);
assert(ne02 == ne11); assert(ne02 == ne11);
@ -10422,14 +10422,17 @@ static void ggml_compute_forward_get_rows_f32(
assert(ggml_nrows(dst) == nr); assert(ggml_nrows(dst) == nr);
// TODO: multi-thread // TODO: multi-thread
for (int64_t i = 0; i < nr; ++i) { // TODO: same impl for get_rows_q and get_rows_f16
const int64_t r = ((int32_t *) src1->data)[i]; for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
const int64_t i02 = i/ne10; for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_vec_cpy_f32(nc, ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*nb1), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
(float *) ((char *) src0->data + i02*nb02 + r*nb01)); (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
}
}
} }
} }