From 01b6f68a003e4de97098001ae9650ee1c3645b13 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 14 Jan 2024 17:30:22 +0200
Subject: [PATCH] backend : group nodes in a single compute when the user doesn't need them

---
 examples/simple/simple.cpp | 32 ++++++++++++++++++--------------
 ggml-backend.c             | 21 ++++++++++++++-------
 ggml-backend.h             |  8 +++++++-
 3 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index b3ae68492..dac7aa60a 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -8,24 +8,28 @@
 
 // a function that can be called for every computed node during graph evaluation
 // the user can choose to whether to observe the data of the node depending on the tensor parameters
-static bool observe_compute(int node_index, struct ggml_tensor * t, void * user_data) {
+static bool observe_compute(int node_index, struct ggml_tensor * t, bool ask, void * user_data) {
     GGML_UNUSED(user_data);
 
-    // check if name contains soft_max
-    if (strstr(t->name, "soft_max") != 0) {
-        printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
-                __func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
-
-        std::vector<float> t_data(ggml_nelements(t));
-        ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
-
-        // print first row
-        for (int i = 0; i < t->ne[0]; i++) {
-            printf("%8.4f ", t_data[i]);
-        }
-        printf("\n");
+    // the scheduler is asking us if we want to observe this node
+    if (ask) {
+        // check if name contains soft_max
+        return strstr(t->name, "soft_max") != 0;
     }
 
+    // print the node data
+    printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
+            __func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
+
+    std::vector<float> t_data(ggml_nelements(t));
+    ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
+
+    // print first row
+    for (int i = 0; i < t->ne[0]; i++) {
+        printf("%8.4f ", t_data[i]);
+    }
+    printf("\n");
+
     return true;
 }
 
diff --git a/ggml-backend.c b/ggml-backend.c
index ee78f45fa..0ec46ed32 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -1337,18 +1337,25 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
             for (int j = 0; j < split->graph.n_nodes; j++) {
                 struct ggml_tensor * t = split->graph.nodes[j];
 
-                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, j + 1);
+                int k = j;
+
+                // check if the user needs data from this node
+                while (!sched->callback_eval(k, t, true, sched->callback_eval_user_data) && k < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++k];
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, k + 1);
 
                 ggml_backend_graph_compute(split_backend, &gv);
 
-                if (ggml_is_view_op(t->op)) {
-                    continue;
-                }
-
-                // TODO: j is node index in the split, not in the original graph
-                if (!sched->callback_eval(j, t, sched->callback_eval_user_data)) {
+                // TODO: k is node index in the split, not in the original graph
+                // TODO: avoid the ask == true call here
+                if (sched->callback_eval(k, t, true, sched->callback_eval_user_data) &&
+                    !sched->callback_eval(k, t, false, sched->callback_eval_user_data)) {
                     break;
                 }
+
+                j = k;
             }
         }
         uint64_t compute_end_us = ggml_time_us();
diff --git a/ggml-backend.h b/ggml-backend.h
index 057ed1201..0d4ff69ba 100644
--- a/ggml-backend.h
+++ b/ggml-backend.h
@@ -148,8 +148,14 @@ extern "C" {
     struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
     // TODO: propose to rename to ggml_backend_sched_callback_eval
-    typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, void * user_data);
+    typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
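
For reference, below is a minimal user-side sketch of the new ask/observe protocol. It is not part of the patch: the "ffn_out" name filter is an arbitrary example, and the registration call is assumed to go through the ggml_backend_sched eval-callback setter (shown here as ggml_backend_sched_set_eval_callback).

// Minimal sketch of a callback using the two-phase (ask/observe) protocol.
// Assumptions: the "ffn_out" filter string is arbitrary, and the setter name
// ggml_backend_sched_set_eval_callback is taken from ggml-backend.h.
#include <cstdio>
#include <cstring>

#include "ggml.h"
#include "ggml-backend.h"

static bool sketch_eval_cb(int node_index, struct ggml_tensor * t, bool ask, void * user_data) {
    GGML_UNUSED(user_data);

    if (ask) {
        // phase 1: the scheduler asks whether this node should be observed;
        // consecutive nodes that return false here are grouped into a single
        // ggml_backend_graph_compute() call
        return strstr(t->name, "ffn_out") != NULL;
    }

    // phase 2: the node has been computed and its data can be inspected
    printf("observed node %5d: %s (%s)\n", node_index, t->name, ggml_op_name(t->op));

    // returning false here cancels the remainder of the graph compute
    return true;
}

// registration (assumed setter):
//   ggml_backend_sched_set_eval_callback(sched, sketch_eval_cb, /*user_data=*/NULL);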