mpi : fix after master merge

This commit is contained in:
Georgi Gerganov 2023-07-09 22:23:04 +03:00
parent 81c5ddd532
commit 0492363137
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 6 additions and 4 deletions

View File

@ -78,7 +78,8 @@ void ggml_mpi_graph_compute(
struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
int n_layers) {
int n_layers,
int n_threads) {
const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size;
@ -194,7 +195,7 @@ void ggml_mpi_graph_compute(
//fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
}
ggml_graph_compute(ctx, gf);
ggml_graph_compute_with_ctx(ctx, gf, n_threads);
//fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank);

View File

@ -28,7 +28,8 @@ void ggml_mpi_graph_compute(
struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
int n_layers);
int n_layers,
int n_threads);
#ifdef __cplusplus
}

View File

@ -1657,7 +1657,7 @@ static bool llama_eval_internal(
ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
}
#elif GGML_USE_MPI
ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer);
ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_threads);
cur = gf.nodes[gf.n_nodes - 1];
#else