From c717c5185f76a07422fbf6d66b58bfe7b6f0fd9a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 9 Jul 2023 16:40:16 +0300 Subject: [PATCH] mpi : various fixes - communication now works but results are wrong --- ggml-mpi.c | 21 ++++++++++++++------- llama.cpp | 10 +++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml-mpi.c b/ggml-mpi.c index 8bf4468a1..e890d24d1 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -108,19 +108,17 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_rank - 1; - // fprintf(stderr, "(%d) Receiving from (%d)\n", mpi_rank, mpi_rank_src); - const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); + const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); - // fprintf(stderr, "(%d) Received from (%d)\n", mpi_rank, mpi_rank_src); } } else { // node 0 sends the input data to node 1 { const int mpi_rank_dst = mpi_rank + 1; - const int retval = MPI_Send(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + const int retval = MPI_Send(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); - // fprintf(stderr, "(%d) Sent to (%d)\n", mpi_rank, mpi_rank_dst); } // recv the output data from the last node @@ -129,7 +127,8 @@ void ggml_mpi_graph_compute( const int mpi_rank_src = mpi_size - 1; - const int retval = MPI_Recv(embd, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(embd), mpi_rank_src); + const int retval = MPI_Recv(embd->data, ggml_nelements(embd), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } } @@ -165,20 +164,28 @@ void ggml_mpi_graph_compute( for (int i = 1; i < idx_l1 - idx_l0; i++) { gf->nodes[i] = gf->nodes[idx_l0 + i]; gf->grads[i] = gf->grads[idx_l0 + i]; + + //fprintf(stderr, "%s: node %d: %d -> %d\n", __func__, mpi_rank, idx_l0 + i, i); } gf->n_nodes = idx_l1 - idx_l0; + + //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); } ggml_graph_compute(ctx, gf); + //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); + // send the output data to the next node if (mpi_rank > 0) { struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1]; const int mpi_rank_dst = (mpi_rank + 1) % mpi_size; - const int retval = MPI_Send(output, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + //fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst); + + const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); GGML_ASSERT(retval == MPI_SUCCESS); } } diff --git a/llama.cpp b/llama.cpp index fa8030c36..08a5bd284 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1287,6 +1287,10 @@ static bool llama_eval_internal( LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + // enforce that the first token is BOS if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) { fprintf(stderr, "%s: first token must be BOS\n", __func__); @@ -1331,10 +1335,6 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; -#ifdef GGML_USE_MPI - ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif - if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1636,7 +1636,7 @@ static bool llama_eval_internal( ggml_graph_compute(ctx0, &gf); } #elif GGML_USE_MPI - ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer, n_embd, n_tokens); + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); #else ggml_graph_compute(ctx0, &gf); #endif