From a1cdd29cd20a1c1fe7a7c33f77dbf38a4db90296 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 May 2023 19:49:04 +0300 Subject: [PATCH] ggml : rms_norm in chunks --- ggml.c | 67 +++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/ggml.c b/ggml.c index d3b806cc0..65ec66f24 100644 --- a/ggml.c +++ b/ggml.c @@ -9033,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + atomic_store(params->aic, 0); + return; } GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; + const int ith = params->ith; UNUSED(ith); const int nth = params->nth; const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; + const int64_t ne03 = src0->ne[3]; UNUSED(ne03); const size_t nb01 = src0->nb[1]; const size_t nb02 = src0->nb[2]; @@ -9056,30 +9058,45 @@ static void ggml_compute_forward_rms_norm_f32( const float eps = 1e-6f; // TODO: make this a parameter - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const int nr = ggml_nrows(src0); + const int dr = (nr + 8*nth - 1)/(8*nth); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); - } + while (true) { + const int ir0 = atomic_fetch_add(params->aic, dr); - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - - const float scale = 1.0f/sqrtf(mean + eps); - - ggml_vec_scale_f32(ne00, y, scale); + for (int ir = ir0; ir < ir0 + dr; ++ir) { + if (ir >= nr) { + break; } + + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + + if (ir0 + dr >= nr) { + break; } } } @@ -9754,11 +9771,9 @@ static void ggml_compute_forward_mul_mat_q_f32( const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; - const int ith = params->ith; + const int ith = params->ith; UNUSED(ith); const int nth = params->nth; - UNUSED(ith); - GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne2 == ne12);