ggml : update soft max cpu

This commit is contained in:
Georgi Gerganov 2023-11-30 20:05:41 +02:00
parent ebd062bc19
commit c7c8dabcf7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

11
ggml.c
View File

@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
for (int i1 = ir0; i1 < ir1; i1++) {
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32(
// broadcast the mask across rows
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
float * wp = wdata;
for (int i = 0; i < nc; i++) {
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (mp) {
ggml_vec_acc_f32(nc, wp, mp);
}
#ifndef NDEBUG
@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_SOFT_MAX:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;