diff --git a/ggml.c b/ggml.c index 101cb733b..d3b806cc0 100644 --- a/ggml.c +++ b/ggml.c @@ -3590,6 +3590,9 @@ struct ggml_compute_params { // work buffer for all threads size_t wsize; void * wdata; + + // atomic counter used to distribute chunks of work + atomic_int * aic; }; // @@ -9754,6 +9757,8 @@ static void ggml_compute_forward_mul_mat_q_f32( const int ith = params->ith; const int nth = params->nth; + UNUSED(ith); + GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne2 == ne12); @@ -9867,6 +9872,8 @@ static void ggml_compute_forward_mul_mat_q_f32( } } + atomic_store(params->aic, 0); + return; } @@ -9874,43 +9881,48 @@ static void ggml_compute_forward_mul_mat_q_f32( return; } - // parallelize by src0 rows using ggml_vec_dot_q - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - void * wdata = params->wdata; const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + // parallelize by src0 rows using ggml_vec_dot_q - const int i13 = i03; - const int i12 = i02; + const int nr = ggml_nrows(src0); + const int dr = (nr + 8*nth - 1)/(8*nth); - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; + while (true) { + const int ir0 = atomic_fetch_add(params->aic, dr); - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); + for (int ir = ir0; ir < ir0 + dr; ++ir) { + if (ir >= nr) { + break; + } - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - assert(ne00 % 32 == 0); + const int i13 = i03; + const int i12 = i02; - for (int64_t ic = 0; ic < ne11; ++ic) { - vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int64_t ic = 0; ic < ne11; ++ic) { + vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); + } + } + + if (ir0 + dr >= nr) { + break; } } @@ -13749,6 +13761,7 @@ struct ggml_compute_state_shared { // synchronization primitives atomic_int n_ready; + atomic_int aic; atomic_bool has_work; atomic_bool stop; // stop all threads }; @@ -13817,6 +13830,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) /*.spin =*/ GGML_LOCK_INITIALIZER, /*.n_threads =*/ n_threads, /*.n_ready =*/ 0, + /*.aic =*/ 0, /*.has_work =*/ false, /*.stop =*/ false, }; @@ -13837,6 +13851,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) .nth = n_threads, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL, + .aic = &state_shared.aic, }, .node = NULL, .shared = &state_shared, @@ -14126,6 +14141,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) /*.nth =*/ node->n_tasks, /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, + /*.aic =*/ &state_shared.aic, }; ggml_compute_forward(¶ms, node); @@ -14149,6 +14165,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) .nth = node->n_tasks, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL, + .aic = &state_shared.aic, }; workers[j].node = node; } @@ -14164,6 +14181,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } params.type = GGML_TASK_COMPUTE; + params.aic = &state_shared.aic; ggml_compute_forward(¶ms, node); // wait for thread pool @@ -14204,6 +14222,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) .nth = node->n_tasks, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL, + .aic = &state_shared.aic, }; workers[j].node = node; }