From c7c8dabcf74da8e66577948d2498e140d299e7e8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 30 Nov 2023 20:05:41 +0200
Subject: [PATCH] ggml : update soft max cpu

---
 ggml.c | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/ggml.c b/ggml.c
index 9dc4678cb..e2687ef4f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32(
         // broadcast the mask across rows
         float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
 
-        float * wp = wdata;
-        for (int i = 0; i < nc; i++) {
-            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
         }
 
 #ifndef NDEBUG
@@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
-                    n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
 
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                 } break;