mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-06 16:40:34 +01:00
ggml : update soft max cpu
This commit is contained in:
parent
ebd062bc19
commit
c7c8dabcf7
11
ggml.c
11
ggml.c
@ -10602,7 +10602,7 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
@ -10611,9 +10611,10 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
// broadcast the mask across rows
|
||||
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
|
||||
|
||||
float * wp = wdata;
|
||||
for (int i = 0; i < nc; i++) {
|
||||
wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
if (mp) {
|
||||
ggml_vec_acc_f32(nc, wp, mp);
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
@ -15939,7 +15940,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
|
||||
n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
|
||||
|
||||
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
|
||||
} break;
|
||||
|
Loading…
Reference in New Issue
Block a user