mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099)

This commit is contained in:
Georgi Gerganov 2023-07-09 14:08:53 +03:00
parent ef61acfbf5
commit 3232db628c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
11 changed files with 134 additions and 67 deletions

View File

@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
} }
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
llama_init_backend(params.numa); llama_backend_init(params.numa);
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;

View File

@ -35,7 +35,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
llama_init_backend(params.numa); llama_backend_init(params.numa);
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;
@ -93,5 +93,7 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
llama_backend_free();
return 0; return 0;
} }

View File

@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
llama_init_backend(params.numa); llama_backend_init(params.numa);
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;
@ -671,7 +671,7 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
llama_finalize_backend(); llama_backend_free();
return 0; return 0;
} }

View File

@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
llama_init_backend(params.numa); llama_backend_init(params.numa);
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;
@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
llama_finalize_backend(); llama_backend_free();
return 0; return 0;
} }

View File

@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
usage(argv[0]); usage(argv[0]);
} }
llama_init_backend(false); llama_backend_init(false);
// parse command line arguments // parse command line arguments
const std::string fname_inp = argv[arg_idx]; const std::string fname_inp = argv[arg_idx];
@ -257,5 +257,7 @@ int main(int argc, char ** argv) {
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
} }
llama_backend_free();
return 0; return 0;
} }

View File

@ -1079,7 +1079,7 @@ int main(int argc, char **argv)
params.model_alias = params.model; params.model_alias = params.model;
} }
llama_init_backend(params.numa); llama_backend_init(params.numa);
LOG_INFO("build info", {{"build", BUILD_NUMBER}, LOG_INFO("build info", {{"build", BUILD_NUMBER},
{"commit", BUILD_COMMIT}}); {"commit", BUILD_COMMIT}});
@ -1309,5 +1309,7 @@ int main(int argc, char **argv)
return 1; return 1;
} }
llama_backend_free();
return 0; return 0;
} }

View File

@ -66,7 +66,7 @@ int main(int argc, char ** argv)
// Init LLM : // Init LLM :
//--------------------------------- //---------------------------------
llama_init_backend(params.numa); llama_backend_init(params.numa);
llama_model * model; llama_model * model;
llama_context * ctx; llama_context * ctx;
@ -173,7 +173,7 @@ int main(int argc, char ** argv)
llama_free( ctx ); llama_free( ctx );
llama_free_model( model ); llama_free_model( model );
llama_finalize_backend(); llama_backend_free();
return 0; return 0;
} }

View File

@ -2,9 +2,11 @@
#include "ggml.h" #include "ggml.h"
#include <mpi.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <mpi.h>
#define UNUSED GGML_UNUSED #define UNUSED GGML_UNUSED
struct ggml_mpi_tensor_info { struct ggml_mpi_tensor_info {
@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv(
struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_tensor * ggml_mpi_send_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor *src, struct ggml_tensor * src,
int dst_rank) { int dst_rank) {
struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
// TODO how/when to free this struct? // TODO how/when to free this struct?
@ -67,9 +68,9 @@ struct ggml_tensor * ggml_mpi_send_tensor(
struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_tensor * ggml_mpi_recv_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor *parent, struct ggml_tensor * parent,
struct ggml_tensor *dst, struct ggml_tensor * dst,
int src_rank) { int src_rank) {
struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
// TODO how/when to free this struct? // TODO how/when to free this struct?
@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor(
return result; return result;
} }
struct ggml_mpi_context {
int mpi_rank;
int mpi_size;
};
void ggml_mpi_backend_init(void) {
MPI_Init(NULL, NULL);
}
void ggml_mpi_backend_free(void) {
MPI_Finalize();
}
struct ggml_mpi_context * ggml_mpi_init(void) {
struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
return ctx;
}
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
free(ctx);
}
int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
return ctx->mpi_rank;
}
struct ggml_tensor * ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
int n_embd,
int * n_tokens,
int * n_past,
int * n_threads) {
struct ggml_tensor * res = NULL;
// synchronize the worker node parameters with the root node
MPI_Barrier(MPI_COMM_WORLD);
MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
if (ctx_mpi->mpi_rank > 0) {
res = ggml_mpi_recv_tensor(ctx, NULL,
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1);
ggml_set_name(res, "mpi_recv");
}
return res;
}

View File

@ -9,13 +9,31 @@ extern "C" {
struct ggml_tensor * ggml_mpi_send_tensor( struct ggml_tensor * ggml_mpi_send_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor *src, struct ggml_tensor * src,
int dst_rank); int dst_rank);
struct ggml_tensor * ggml_mpi_recv_tensor( struct ggml_tensor * ggml_mpi_recv_tensor(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor *parent, struct ggml_tensor * parent,
struct ggml_tensor *dst, struct ggml_tensor * dst,
int src_rank); int src_rank);
struct ggml_mpi_context;
void ggml_mpi_backend_init(void);
void ggml_mpi_backend_free(void);
struct ggml_mpi_context * ggml_mpi_init(void);
void ggml_mpi_free(struct ggml_mpi_context * ctx);
int ggml_mpi_rank(struct ggml_mpi_context * ctx);
struct ggml_tensor * ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
int n_embd,
int * n_tokens,
int * n_past,
int * n_threads);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -52,10 +52,6 @@
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#ifdef GGML_USE_MPI
#include <mpi.h>
#endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
@ -337,8 +333,9 @@ struct llama_context {
ggml_metal_context * ctx_metal = NULL; ggml_metal_context * ctx_metal = NULL;
#endif #endif
int mpi_rank; #ifdef GGML_USE_MPI
int mpi_size; ggml_mpi_context * ctx_mpi = NULL;
#endif
int buf_last = 0; int buf_last = 0;
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
@ -859,7 +856,7 @@ bool llama_mlock_supported() {
return llama_mlock::SUPPORTED; return llama_mlock::SUPPORTED;
} }
void llama_init_backend(bool numa) { void llama_backend_init(bool numa) {
ggml_time_init(); ggml_time_init();
// needed to initialize f16 tables // needed to initialize f16 tables
@ -872,14 +869,15 @@ void llama_init_backend(bool numa) {
if (numa) { if (numa) {
ggml_numa_init(); ggml_numa_init();
} }
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
MPI_Init(NULL, NULL); ggml_mpi_backend_init();
#endif #endif
} }
void llama_finalize_backend() { void llama_backend_free() {
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
MPI_Finalize(); ggml_mpi_backend_free();
#endif #endif
} }
@ -1282,9 +1280,9 @@ static bool llama_eval_internal(
llama_context & lctx, llama_context & lctx,
const llama_token * tokens, const llama_token * tokens,
const float * embd, const float * embd,
const int n_tokens, int n_tokens,
const int n_past, int n_past,
const int n_threads, int n_threads,
const char * cgraph_fname) { const char * cgraph_fname) {
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
@ -1333,16 +1331,14 @@ static bool llama_eval_internal(
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
if (lctx.mpi_rank > 0) {
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
inpL = ggml_mpi_recv_tensor(ctx0, NULL, inpL = ggml_mpi_eval_init(lctx.ctx_mpi, ctx0, n_embd, &n_tokens, &n_past, &n_threads);
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N),
lctx.mpi_rank-1); if (inpL) {
ggml_set_name(inpL, "mpi_recv"); // only rank 0 loads uses the input
#else } else
GGML_ASSERT(false);
#endif #endif
} else if (tokens) { if (tokens) {
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_set_name(embd, "embd"); ggml_set_name(embd, "embd");
memcpy(embd->data, tokens, N*ggml_element_size(embd)); memcpy(embd->data, tokens, N*ggml_element_size(embd));
@ -1585,7 +1581,6 @@ static bool llama_eval_internal(
// input for next layer // input for next layer
inpL = cur; inpL = cur;
} }
lctx.use_buf(ctx0, 0); lctx.use_buf(ctx0, 0);
@ -1601,6 +1596,7 @@ static bool llama_eval_internal(
GGML_ASSERT(false); GGML_ASSERT(false);
#endif #endif
} }
if (lctx.mpi_rank == 0) { if (lctx.mpi_rank == 0) {
if (lctx.mpi_size > 1) { if (lctx.mpi_size > 1) {
#ifdef GGML_USE_MPI #ifdef GGML_USE_MPI
@ -1688,7 +1684,11 @@ static bool llama_eval_internal(
// update kv token count // update kv token count
lctx.kv_self.n = n_past + N; lctx.kv_self.n = n_past + N;
if (lctx.mpi_rank == 0) { #ifdef GGML_USE_MPI
if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
#else
{
#endif
// extract logits // extract logits
{ {
auto & logits_out = lctx.logits; auto & logits_out = lctx.logits;
@ -2659,14 +2659,6 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed); ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all; ctx->logits_all = params.logits_all;
#ifdef GGML_USE_MPI
MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
#else
ctx->mpi_size = 1;
ctx->mpi_rank = 0;
#endif
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
// reserve memory for context buffers // reserve memory for context buffers
@ -2739,15 +2731,17 @@ struct llama_context * llama_new_context_with_model(
} }
#endif #endif
if (ctx->mpi_rank > 0) { #ifdef GGML_USE_MPI
ctx->ctx_mpi = ggml_mpi_init();
if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
// Enter a blocking eval loop with dummy input, letting rank=0 drive the process // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
const std::vector<llama_token> tmp = { llama_token_bos(), }; const std::vector<llama_token> tmp = { llama_token_bos(), };
while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
#ifdef GGML_USE_MPI llama_backend_free();
MPI_Finalize();
#endif
exit(1); exit(1);
} }
#endif
return ctx; return ctx;
} }
@ -3425,13 +3419,6 @@ int llama_eval(
int n_tokens, int n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads) {
#ifdef GGML_USE_MPI
// Synchronize the worker node parameters with the root node
MPI_Barrier(MPI_COMM_WORLD);
MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
#endif
if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
fprintf(stderr, "%s: failed to eval\n", __func__); fprintf(stderr, "%s: failed to eval\n", __func__);
return 1; return 1;

View File

@ -158,9 +158,9 @@ extern "C" {
// Initialize the llama + ggml backend // Initialize the llama + ggml backend
// If numa is true, use NUMA optimizations // If numa is true, use NUMA optimizations
// Call once at the start of the program // Call once at the start of the program
LLAMA_API void llama_init_backend(bool numa); LLAMA_API void llama_backend_init(bool numa);
// Call once at the end of the program - currently only used for MPI // Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_finalize_backend(); LLAMA_API void llama_backend_free();
LLAMA_API int64_t llama_time_us(); LLAMA_API int64_t llama_time_us();