diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3a171925b..154a4113a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -671,5 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_finalize_backend(); + return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index dd54ed3c4..5c4523795 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -172,5 +172,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_finalize_backend(); + return 0; } diff --git a/ggml.c b/ggml.c index afeb72ff0..2f00428d3 100644 --- a/ggml.c +++ b/ggml.c @@ -26,6 +26,8 @@ #include #include +#include + #ifdef GGML_USE_METAL #include #endif @@ -4648,6 +4650,35 @@ struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggm return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL); } +struct ggml_tensor * ggml_send_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *src, + int dst_rank) { + + struct ggml_tensor * result = ggml_new_i32(ctx, 0); + + result->op = GGML_OP_SEND; + result->src0 = src; + result->extra = (void *)dst_rank; + + return result; +} + +struct ggml_tensor * ggml_recv_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank) { + + struct ggml_tensor * result = dst; + + result->op = GGML_OP_RECV; + result->src0 = parent; // just used for graph computation + result->extra = (void *)src_rank; + + return result; +} + struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { memset(tensor->data, 0, ggml_nbytes(tensor)); return tensor; @@ -8191,6 +8222,52 @@ static void ggml_compute_forward_dup( } } +// ggml_compute_forward_recv + +static void ggml_compute_forward_recv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#ifdef GGML_USE_MPI + MPI_Status status; + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, (int)dst->extra); + int retval = MPI_Recv(dst->data, dst->ne[0] * dst->ne[1], MPI_FLOAT, (int)dst->extra, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, (int)dst->extra); + GGML_ASSERT(retval == MPI_SUCCESS); +#else + GGML_ASSERT(false); +#endif +} + +// ggml_compute_forward_send + +static void ggml_compute_forward_send( + const struct ggml_compute_params * params, + struct ggml_tensor * src, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); +#ifdef GGML_USE_MPI + int my_rank; + MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); + int retval = MPI_Send(src->data, src->ne[0] * src->ne[1], MPI_FLOAT, (int)dst->extra, 0, MPI_COMM_WORLD); + // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); + ggml_set_i32(dst, retval); + GGML_ASSERT(retval == MPI_SUCCESS); +#else + GGML_ASSERT(false); +#endif +} + // ggml_compute_forward_add static void ggml_compute_forward_add_f32( @@ -15420,6 +15497,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_dup(params, tensor->src0, tensor); } break; + case GGML_OP_SEND: + { + ggml_compute_forward_send(params, tensor->src0, tensor); + } break; + case GGML_OP_RECV: + { + ggml_compute_forward_recv(params, tensor); + } break; case GGML_OP_ADD: { ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor); @@ -15710,6 +15795,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); } } break; + case GGML_OP_SEND: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RECV: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ADD: { if (src0->grad) { @@ -17058,6 +17151,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; + case GGML_OP_SEND: + case GGML_OP_RECV: case GGML_OP_SET: case GGML_OP_CONT: case GGML_OP_RESHAPE: diff --git a/ggml.h b/ggml.h index 11b51f8bd..aa78f17dd 100644 --- a/ggml.h +++ b/ggml.h @@ -353,6 +353,9 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_COUNT, + + GGML_OP_SEND, + GGML_OP_RECV, }; @@ -556,6 +559,16 @@ extern "C" { GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_send_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *src, + int dst_rank); + GGML_API struct ggml_tensor * ggml_recv_tensor( + struct ggml_context * ctx, + const struct ggml_tensor *parent, + struct ggml_tensor *dst, + int src_rank); + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index a869bbac8..c7de0bc60 100644 --- a/llama.cpp +++ b/llama.cpp @@ -49,6 +49,8 @@ #include #include +#include + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -330,6 +332,9 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif + int mpi_rank; + int mpi_size; + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -864,6 +869,15 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } +#ifdef GGML_USE_MPI + MPI_Init(NULL, NULL); +#endif +} + +void llama_finalize_backend() { +#ifdef GGML_USE_MPI + MPI_Finalize(); +#endif } int64_t llama_time_us() { @@ -1307,7 +1321,16 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { + if (lctx.mpi_rank > 0) { +#ifdef GGML_USE_MPI + inpL = ggml_recv_tensor(ctx0, NULL, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), + lctx.mpi_rank-1); + ggml_set_name(inpL, "recv"); +#else + GGML_ASSERT(false); +#endif + } else if (tokens) { struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); ggml_set_name(embd, "embd"); memcpy(embd->data, tokens, N*ggml_element_size(embd)); @@ -1341,7 +1364,9 @@ static bool llama_eval_internal( } #endif // GGML_USE_CUBLAS - for (int il = 0; il < n_layer; ++il) { + // EMM TODO distribute work more evenly - maybe rank=0 gets the smallest amount? + int slice_size = (n_layer + (lctx.mpi_size - 1)) / lctx.mpi_size; + for (int il = lctx.mpi_rank * slice_size; il < n_layer && il < (lctx.mpi_rank + 1) * slice_size; ++il) { offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS @@ -1556,26 +1581,37 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; +#ifdef GGML_USE_MPI + cur = ggml_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); + ggml_set_name(cur, "send"); +#endif + if (lctx.mpi_rank == 0) { +#ifdef GGML_USE_MPI + cur = ggml_recv_tensor(ctx0, cur, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), + lctx.mpi_size-1); + ggml_set_name(cur, "recv"); +#endif + // norm + { + cur = ggml_rms_norm(ctx0, cur); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); - // norm - { - cur = ggml_rms_norm(ctx0, inpL); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); + embeddings = cur; + } - embeddings = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); } - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - lctx.use_buf(ctx0, -1); // logits -> probs @@ -1632,26 +1668,28 @@ static bool llama_eval_internal( // update kv token count lctx.kv_self.n = n_past + N; - // extract logits - { - auto & logits_out = lctx.logits; + if (lctx.mpi_rank == 0) { + // extract logits + { + auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); - } else { - // return result for just the last token - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } } - } - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + } } if (mem_per_token == 0) { @@ -2603,6 +2641,14 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; +#ifdef GGML_USE_MPI + MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); +#else + ctx->mpi_size = 1; + ctx->mpi_rank = 0; +#endif + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers @@ -2675,6 +2721,16 @@ struct llama_context * llama_new_context_with_model( } #endif + if (ctx->mpi_rank > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector tmp = { llama_token_bos(), }; + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); +#ifdef GGML_USE_MPI + MPI_Finalize(); +#endif + exit(1); + } + return ctx; } @@ -3351,6 +3407,13 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { +#ifdef GGML_USE_MPI + // Synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); +#endif if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; @@ -3434,6 +3497,14 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } +int llama_mpi_rank(const struct llama_context * ctx) { + return ctx->mpi_rank; +} + +int llama_mpi_size(const struct llama_context * ctx) { + return ctx->mpi_size; +} + int llama_get_vocab( const struct llama_context * ctx, const char * * strings, diff --git a/llama.h b/llama.h index 5bb1964bd..1920584c5 100644 --- a/llama.h +++ b/llama.h @@ -145,6 +145,8 @@ extern "C" { // If numa is true, use NUMA optimizations // Call once at the start of the program LLAMA_API void llama_init_backend(bool numa); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_finalize_backend(); LLAMA_API int64_t llama_time_us(); @@ -257,6 +259,8 @@ extern "C" { LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); + LLAMA_API int llama_mpi_rank (const struct llama_context * ctx); + LLAMA_API int llama_mpi_size (const struct llama_context * ctx); // Get the vocabulary as output parameters. // Returns number of results.