ggml : fix ggml_get_rows to take into account ne02 / ne11

This commit is contained in:
Georgi Gerganov 2023-12-09 14:04:54 +02:00
parent ee8fb399aa
commit 9064b1ca05
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

63
ggml.c
View File

@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
assert( dst->ne[0] == nc); assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_type_size(type));
assert(ggml_nrows(dst) == nr); assert(ggml_nrows(dst) == nr);
assert(src0->nb[0] == ggml_type_size(type));
for (int i = 0; i < nr; ++i) { // TODO: multi-thread
const int r = ((int32_t *) src1->data)[i]; for (int64_t i = 0; i < nr; ++i) {
const int64_t r = ((int32_t *) src1->data)[i];
const int64_t i02 = i/ne10;
dequantize_row_q( dequantize_row_q(
(const void *) ((char *) src0->data + r*src0->nb[1]), (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
(float *) ((char *) dst->data + i*dst->nb[1]), nc); (float *) ((char *) dst->data + i*dst->nb[1]), nc);
} }
} }
@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
assert( dst->ne[0] == nc); const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(ggml_fp16_t));
assert(ggml_nrows(dst) == nr); assert(ggml_nrows(dst) == nr);
assert(src0->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < nr; ++i) { // TODO: multi-thread
const int r = ((int32_t *) src1->data)[i]; for (int64_t i = 0; i < nr; ++i) {
const int64_t r = ((int32_t *) src1->data)[i];
const int64_t i02 = i/ne10;
for (int j = 0; j < nc; ++j) { for (int j = 0; j < nc; ++j) {
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
} }
} }
} }
@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
return; return;
} }
const int nc = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS
const int nr = ggml_nelements(src1);
assert( dst->ne[0] == nc); const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_nrows(dst) == nr); assert(ggml_nrows(dst) == nr);
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < nr; ++i) { // TODO: multi-thread
const int r = ((int32_t *) src1->data)[i]; for (int64_t i = 0; i < nr; ++i) {
const int64_t r = ((int32_t *) src1->data)[i];
const int64_t i02 = i/ne10;
ggml_vec_cpy_f32(nc, ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*dst->nb[1]), (float *) ((char *) dst->data + i*dst->nb[1]),
(float *) ((char *) src0->data + r*src0->nb[1])); (float *) ((char *) src0->data + i02*nb02 + r*nb01));
} }
} }