Multi-threaded ggml_cpy (#1035)

* Multi-threaded ggml_cpy

* Update ggml.c

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Also fix wdata offset in ggml_compute_forward_add_q_f32

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
slaren 2023-04-19 00:53:24 +02:00 committed by GitHub
parent 77a73403ca
commit 6667401238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

298
ggml.c
View File

@ -5766,7 +5766,6 @@ static void ggml_compute_forward_dup_f16(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -5778,6 +5777,11 @@ static void ggml_compute_forward_dup_f16(
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3]; const int64_t ne03 = src0->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const size_t nb00 = src0->nb[0]; const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1]; const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2]; const size_t nb02 = src0->nb[2];
@ -5788,19 +5792,40 @@ static void ggml_compute_forward_dup_f16(
const size_t nb2 = dst->nb[2]; const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3]; const size_t nb3 = dst->nb[3];
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); // parallelize by elements
const int ne = ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = MIN(ie0 + dr, ne);
memcpy(
((char *) dst->data + ie0*nb0),
((char *) src0->data + ie0*nb00),
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
return; return;
} }
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type && if (src0->type == dst->type &&
src0->ne[0] == dst->ne[0] && ne00 == ne0 &&
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) { nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
// copy by rows // copy by rows
const size_t rs = ne00*nb00; const size_t rs = ne00*nb00;
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy( memcpy(
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@ -5814,21 +5839,21 @@ static void ggml_compute_forward_dup_f16(
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
if (ggml_is_contiguous(dst)) { if (ggml_is_contiguous(dst)) {
if (src0->nb[0] == sizeof(ggml_fp16_t)) { if (nb00 == sizeof(ggml_fp16_t)) {
if (dst->type == GGML_TYPE_F16) { if (dst->type == GGML_TYPE_F16) {
size_t id = 0; size_t id = 0;
const size_t rs = ne00 * nb00; const size_t rs = ne00 * nb00;
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * dst_ptr = (char *) dst->data + id*rs; memcpy(dst_ptr + id, src0_ptr, rs);
id += rs;
memcpy(dst_ptr, src0_ptr, rs);
id++;
} }
id += rs * (ne01 - ir1);
} }
} }
} else if (dst->type == GGML_TYPE_F32) { } else if (dst->type == GGML_TYPE_F32) {
@ -5837,34 +5862,39 @@ static void ggml_compute_forward_dup_f16(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else if (ggml_is_quantized(dst->type)) { } else if (ggml_is_quantized(dst->type)) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0; size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data; size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); char * dst_ptr = (char *) dst->data;
float * src0_f32 = (float *) params->wdata;
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
// convert to f32 and quantize
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
} }
quantize_row_q(src0_f32, dst_ptr + id, ne00); quantize_row_q(src0_f32, dst_ptr + id, ne00);
id += dst_row_size; id += rs;
} }
id += rs * (ne01 - ir1);
} }
} }
} else { } else {
@ -5879,7 +5909,8 @@ static void ggml_compute_forward_dup_f16(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@ -5887,6 +5918,7 @@ static void ggml_compute_forward_dup_f16(
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else if (dst->type == GGML_TYPE_F16) { } else if (dst->type == GGML_TYPE_F16) {
@ -5895,7 +5927,8 @@ static void ggml_compute_forward_dup_f16(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@ -5903,6 +5936,7 @@ static void ggml_compute_forward_dup_f16(
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else { } else {
@ -5921,7 +5955,20 @@ static void ggml_compute_forward_dup_f16(
if (dst->type == GGML_TYPE_F16) { if (dst->type == GGML_TYPE_F16) {
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) { for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@ -5942,25 +5989,51 @@ static void ggml_compute_forward_dup_f16(
} }
} }
} }
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
} }
} }
} else if (dst->type == GGML_TYPE_F32) { } else if (dst->type == GGML_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) { for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
if (++i10 == ne00) { if (++i10 == ne0) {
i10 = 0; i10 = 0;
if (++i11 == ne01) { if (++i11 == ne1) {
i11 = 0; i11 = 0;
if (++i12 == ne02) { if (++i12 == ne2) {
i12 = 0; i12 = 0;
if (++i13 == ne03) { if (++i13 == ne3) {
i13 = 0; i13 = 0;
} }
} }
@ -5968,6 +6041,19 @@ static void ggml_compute_forward_dup_f16(
} }
} }
} }
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
} }
} }
} else { } else {
@ -5979,7 +6065,6 @@ static void ggml_compute_forward_dup_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -5991,6 +6076,11 @@ static void ggml_compute_forward_dup_f32(
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3]; const int64_t ne03 = src0->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
const size_t nb00 = src0->nb[0]; const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1]; const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2]; const size_t nb02 = src0->nb[2];
@ -6001,19 +6091,40 @@ static void ggml_compute_forward_dup_f32(
const size_t nb2 = dst->nb[2]; const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3]; const size_t nb3 = dst->nb[3];
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]); // parallelize by elements
const int ne = ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = MIN(ie0 + dr, ne);
memcpy(
((char *) dst->data + ie0*nb0),
((char *) src0->data + ie0*nb00),
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
return; return;
} }
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type && if (src0->type == dst->type &&
src0->ne[0] == dst->ne[0] && ne00 == ne0 &&
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) { nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
// copy by rows // copy by rows
const size_t rs = ne00*nb00; const size_t rs = ne00*nb00;
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy( memcpy(
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@ -6026,21 +6137,21 @@ static void ggml_compute_forward_dup_f32(
if (ggml_is_contiguous(dst)) { if (ggml_is_contiguous(dst)) {
// TODO: simplify // TODO: simplify
if (src0->nb[0] == sizeof(float)) { if (nb00 == sizeof(float)) {
if (dst->type == GGML_TYPE_F32) { if (dst->type == GGML_TYPE_F32) {
size_t id = 0; size_t id = 0;
const size_t rs = ne00 * nb00; const size_t rs = ne00 * nb00;
char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * dst_ptr = (char *) dst->data + id*rs; memcpy(dst_ptr + id, src0_ptr, rs);
id += rs;
memcpy(dst_ptr, src0_ptr, rs);
id++;
} }
id += rs * (ne01 - ir1);
} }
} }
} else if (dst->type == GGML_TYPE_F16) { } else if (dst->type == GGML_TYPE_F16) {
@ -6049,7 +6160,8 @@ static void ggml_compute_forward_dup_f32(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@ -6057,21 +6169,25 @@ static void ggml_compute_forward_dup_f32(
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else if (ggml_is_quantized(dst->type)) { } else if (ggml_is_quantized(dst->type)) {
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
size_t id = 0; size_t id = 0;
uint8_t * dst_ptr = (uint8_t *) dst->data; size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); char * dst_ptr = (char *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
quantize_row_q(src0_ptr, dst_ptr + id, ne00); quantize_row_q(src0_ptr, dst_ptr + id, ne00);
id += dst_row_size; id += rs;
} }
id += rs * (ne01 - ir1);
} }
} }
} else { } else {
@ -6086,7 +6202,8 @@ static void ggml_compute_forward_dup_f32(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@ -6094,6 +6211,7 @@ static void ggml_compute_forward_dup_f32(
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else if (dst->type == GGML_TYPE_F16) { } else if (dst->type == GGML_TYPE_F16) {
@ -6102,7 +6220,8 @@ static void ggml_compute_forward_dup_f32(
for (int i03 = 0; i03 < ne03; i03++) { for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) { for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = 0; i01 < ne01; i01++) { id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) { for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
@ -6110,6 +6229,7 @@ static void ggml_compute_forward_dup_f32(
id++; id++;
} }
} }
id += ne00 * (ne01 - ir1);
} }
} }
} else { } else {
@ -6121,6 +6241,7 @@ static void ggml_compute_forward_dup_f32(
} }
// dst counters // dst counters
int64_t i10 = 0; int64_t i10 = 0;
int64_t i11 = 0; int64_t i11 = 0;
int64_t i12 = 0; int64_t i12 = 0;
@ -6129,20 +6250,34 @@ static void ggml_compute_forward_dup_f32(
if (dst->type == GGML_TYPE_F32) { if (dst->type == GGML_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
i11++;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) { for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, sizeof(float)); memcpy(dst_ptr, src0_ptr, sizeof(float));
if (++i10 == dst->ne[0]) { if (++i10 == ne0) {
i10 = 0; i10 = 0;
if (++i11 == dst->ne[1]) { if (++i11 == ne1) {
i11 = 0; i11 = 0;
if (++i12 == dst->ne[2]) { if (++i12 == ne2) {
i12 = 0; i12 = 0;
if (++i13 == dst->ne[3]) { if (++i13 == ne3) {
i13 = 0; i13 = 0;
} }
} }
@ -6150,25 +6285,51 @@ static void ggml_compute_forward_dup_f32(
} }
} }
} }
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
} }
} }
} else if (dst->type == GGML_TYPE_F16) { } else if (dst->type == GGML_TYPE_F16) {
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) { i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) { for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
if (++i10 == dst->ne[0]) { if (++i10 == ne0) {
i10 = 0; i10 = 0;
if (++i11 == dst->ne[1]) { if (++i11 == ne1) {
i11 = 0; i11 = 0;
if (++i12 == dst->ne[2]) { if (++i12 == ne2) {
i12 = 0; i12 = 0;
if (++i13 == dst->ne[3]) { if (++i13 == ne3) {
i13 = 0; i13 = 0;
} }
} }
@ -6176,6 +6337,19 @@ static void ggml_compute_forward_dup_f32(
} }
} }
} }
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
} }
} }
} else { } else {
@ -6436,7 +6610,7 @@ static void ggml_compute_forward_add_q_f32(
const int ir0 = dr*ith; const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr); const int ir1 = MIN(ir0 + dr, nr);
float * wdata = (float*) params->wdata + ne00 * ith; float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices // src0 indices
@ -10636,11 +10810,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_DUP: case GGML_OP_DUP:
{ {
node->n_tasks = 1; node->n_tasks = n_threads;
size_t cur = 0; size_t cur = 0;
if (ggml_is_quantized(node->type)) { if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0]; cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
} }
work_size = MAX(work_size, cur); work_size = MAX(work_size, cur);