mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 11:23:56 +01:00
mpi : add names for layer inputs + prep ggml_mpi_graph_compute()
This commit is contained in:
parent
3232db628c
commit
e339d35579
@ -135,3 +135,12 @@ struct ggml_tensor * ggml_mpi_eval_init(
|
|||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_mpi_graph_compute(
|
||||||
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
int n_layers,
|
||||||
|
int n_embd,
|
||||||
|
int n_tokens) {
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
struct ggml_context;
|
struct ggml_context;
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
|
struct ggml_cgraph;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -35,6 +36,13 @@ struct ggml_tensor * ggml_mpi_eval_init(
|
|||||||
int * n_past,
|
int * n_past,
|
||||||
int * n_threads);
|
int * n_threads);
|
||||||
|
|
||||||
|
void ggml_mpi_graph_compute(
|
||||||
|
struct ggml_mpi_context * ctx_mpi,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
int n_layers,
|
||||||
|
int n_embd,
|
||||||
|
int n_tokens);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
32
llama.cpp
32
llama.cpp
@ -1372,9 +1372,9 @@ static bool llama_eval_internal(
|
|||||||
}
|
}
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
|
|
||||||
// EMM TODO distribute work more evenly - maybe rank=0 gets the smallest amount?
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
int slice_size = (n_layer + (lctx.mpi_size - 1)) / lctx.mpi_size;
|
ggml_format_name(inpL, "layer_inp_%d", il);
|
||||||
for (int il = lctx.mpi_rank * slice_size; il < n_layer && il < (lctx.mpi_rank + 1) * slice_size; ++il) {
|
|
||||||
offload_func_t offload_func = llama_nop;
|
offload_func_t offload_func = llama_nop;
|
||||||
|
|
||||||
#ifdef GGML_USE_CUBLAS
|
#ifdef GGML_USE_CUBLAS
|
||||||
@ -1588,29 +1588,9 @@ static bool llama_eval_internal(
|
|||||||
// used at the end to optionally extract the embeddings
|
// used at the end to optionally extract the embeddings
|
||||||
struct ggml_tensor * embeddings = NULL;
|
struct ggml_tensor * embeddings = NULL;
|
||||||
|
|
||||||
if (lctx.mpi_size > 1) {
|
|
||||||
#ifdef GGML_USE_MPI
|
|
||||||
cur = ggml_mpi_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size);
|
|
||||||
ggml_set_name(cur, "mpi_send");
|
|
||||||
#else
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lctx.mpi_rank == 0) {
|
|
||||||
if (lctx.mpi_size > 1) {
|
|
||||||
#ifdef GGML_USE_MPI
|
|
||||||
cur = ggml_mpi_recv_tensor(ctx0, cur,
|
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N),
|
|
||||||
lctx.mpi_size-1);
|
|
||||||
ggml_set_name(cur, "mpi_recv");
|
|
||||||
#else
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
cur = ggml_rms_norm(ctx0, cur);
|
cur = ggml_rms_norm(ctx0, inpL);
|
||||||
offload_func_nr(cur);
|
offload_func_nr(cur);
|
||||||
ggml_set_name(cur, "rms_norm_2");
|
ggml_set_name(cur, "rms_norm_2");
|
||||||
|
|
||||||
@ -1622,11 +1602,9 @@ static bool llama_eval_internal(
|
|||||||
embeddings = cur;
|
embeddings = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// lm_head
|
// lm_head
|
||||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
ggml_set_name(cur, "result_output");
|
ggml_set_name(cur, "result_output");
|
||||||
}
|
|
||||||
|
|
||||||
lctx.use_buf(ctx0, -1);
|
lctx.use_buf(ctx0, -1);
|
||||||
|
|
||||||
@ -1659,6 +1637,8 @@ static bool llama_eval_internal(
|
|||||||
|
|
||||||
ggml_graph_compute(ctx0, &gf);
|
ggml_graph_compute(ctx0, &gf);
|
||||||
}
|
}
|
||||||
|
#elif GGML_USE_MPI
|
||||||
|
ggml_mpi_graph_compute(lctx.ctx_mpi, &gf, n_layer, n_embd, n_tokens);
|
||||||
#else
|
#else
|
||||||
ggml_graph_compute(ctx0, &gf);
|
ggml_graph_compute(ctx0, &gf);
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user