diff --git a/ggml-mpi.c b/ggml-mpi.c index 4bde41808..70639e078 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -78,7 +78,8 @@ void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers) { + int n_layers, + int n_threads) { const int mpi_rank = ctx_mpi->rank; const int mpi_size = ctx_mpi->size; @@ -194,7 +195,7 @@ void ggml_mpi_graph_compute( //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); } - ggml_graph_compute(ctx, gf); + ggml_graph_compute_with_ctx(ctx, gf, n_threads); //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); diff --git a/ggml-mpi.h b/ggml-mpi.h index 02e125cfb..2ad0a4386 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -28,7 +28,8 @@ void ggml_mpi_graph_compute( struct ggml_mpi_context * ctx_mpi, struct ggml_context * ctx, struct ggml_cgraph * gf, - int n_layers); + int n_layers, + int n_threads); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 325db7d56..322e37a7d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1657,7 +1657,7 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); } #elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads); cur = gf.nodes[gf.n_nodes - 1]; #else