ggml : use less ggml_mul tasks when src0 rows are few

This commit is contained in:
Georgi Gerganov 2023-08-30 19:37:26 +03:00
parent 253eab8ae1
commit df54d2f1d4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

11
ggml.c
View File

@ -9329,11 +9329,12 @@ static void ggml_compute_forward_mul_f32(
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
const int ith = params->ith;
const int nth = params->nth;
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
const int ith = params->ith;
const int nth = params->nth;
#ifdef GGML_USE_CLBLAST #ifdef GGML_USE_CLBLAST
if (src1->backend == GGML_BACKEND_GPU) { if (src1->backend == GGML_BACKEND_GPU) {
@ -17229,7 +17230,13 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} }
} break; } break;
case GGML_OP_SILU_BACK: case GGML_OP_SILU_BACK:
{
n_tasks = n_threads;
} break;
case GGML_OP_MUL: case GGML_OP_MUL:
{
n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
} break;
case GGML_OP_NORM: case GGML_OP_NORM:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK: case GGML_OP_RMS_NORM_BACK: