mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-30 22:03:03 +01:00)

ggml : fix ggml_get_rows to take into account ne02 / ne11

commit 9064b1ca05
parent ee8fb399aa

ggml.c | 63
@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
                      (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
     }
 }
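For context (not part of the commit): src1 is read as a 2D tensor of int32 row ids with shape [ne10, ne11] and src0 as a 3D tensor of rows with shape [ne00, ne01, ne02]; output row i takes row r = ids[i] from slab i02 = i/ne10 of src0, which is why the new code asserts ne02 == ne11. The sketch below mirrors the f32 path on plain contiguous arrays, with memcpy standing in for ggml_vec_cpy_f32; all names and shapes are illustrative, not ggml API.

#include <assert.h>
#include <stdint.h>
#include <string.h>

// Illustrative sketch of the batched row gather introduced by this commit.
// src0: ne02 slabs, each ne01 rows of ne00 floats, stored contiguously.
// ids:  ne10*ne11 int32 row indices; the j-th group of ne10 ids selects
//       rows from slab j of src0 (hence the ne02 == ne11 requirement).
// dst:  ne10*ne11 output rows of ne00 floats.
static void get_rows_f32_sketch(
        const float   * src0, int64_t ne00, int64_t ne01, int64_t ne02,
        const int32_t * ids,  int64_t ne10, int64_t ne11,
        float         * dst) {
    assert(ne02 == ne11);                     // the check this commit adds

    const int64_t nr = ne10*ne11;             // ggml_nelements(src1)
    for (int64_t i = 0; i < nr; ++i) {
        const int64_t r   = ids[i];           // row index within the slab
        const int64_t i02 = i/ne10;           // which slab of src0 to read

        assert(r >= 0 && r < ne01);

        // counterpart of: src0->data + i02*nb02 + r*nb01
        const float * src_row = src0 + i02*ne01*ne00 + r*ne00;
        memcpy(dst + i*ne00, src_row, (size_t)ne00*sizeof(float));
    }
}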
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
             ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(float));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+                (float *) ((char *) src0->data + i02*nb02 + r*nb01));
     }
 }
 
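As a quick illustration of the new i02 = i/ne10 mapping (again a standalone sketch, not code from the commit): with ne10 = 2 ids per batch and ne11 = 3 batches, src1 holds 6 ids and consecutive groups of ne10 output rows read from consecutive slabs of src0.

#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Illustrative shapes: ne10 ids per batch, ne11 batches (so ne02 must be 3).
    const int64_t ne10 = 2;
    const int64_t ne11 = 3;

    for (int64_t i = 0; i < ne10*ne11; ++i) {
        const int64_t i02 = i/ne10;   // slab of src0 used for output row i
        printf("dst row %lld <- src0 slab %lld\n", (long long) i, (long long) i02);
    }
    return 0;
}

Output rows 0-1 read slab 0, rows 2-3 read slab 1, and rows 4-5 read slab 2; before this change every row was read from src0 as if it were a single 2D matrix.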