backend : clean-up the implementation

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-01-15 15:52:41 +02:00
parent 01b6f68a00
commit 83f3d7a83c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 22 additions and 20 deletions

View File

@ -8,19 +8,20 @@
// a function that can be called for every computed node during graph evaluation // a function that can be called for every computed node during graph evaluation
// the user can choose whether to observe the data of the node depending on the tensor parameters // the user can choose whether to observe the data of the node depending on the tensor parameters
static bool observe_compute(int node_index, struct ggml_tensor * t, bool ask, void * user_data) { static bool observe_compute(struct ggml_tensor * t, bool ask, void * user_data) {
GGML_UNUSED(user_data); GGML_UNUSED(user_data);
// the scheduler is asking us if we want to observe this node // the scheduler is asking us if we want to observe this node
if (ask) { if (ask) {
// check if name contains soft_max // check if name contains soft_max (customize to your needs)
return strstr(t->name, "soft_max") != 0; return strstr(t->name, "soft_max") != 0;
} }
// print the node data // print the node info
printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n", printf("%s: t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
__func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); __func__, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
// this will copy the data to host memory (if needed)
std::vector<float> t_data(ggml_nelements(t)); std::vector<float> t_data(ggml_nelements(t));
ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t)); ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));

View File

@ -1334,28 +1334,31 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
//ggml_backend_synchronize(split_backend); // necessary to measure compute time //ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else { } else {
// similar to ggml_backend_compare_graph_backend // similar to ggml_backend_compare_graph_backend
for (int j = 0; j < split->graph.n_nodes; j++) { for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
struct ggml_tensor * t = split->graph.nodes[j]; struct ggml_tensor * t = split->graph.nodes[j0];
int k = j; int j1 = j0;
// check if the user needs data from this node // determine the range [j0, j1] of nodes that can be computed together
while (!sched->callback_eval(k, t, true, sched->callback_eval_user_data) && k < split->graph.n_nodes - 1) { while (j1 < split->graph.n_nodes - 1) {
t = split->graph.nodes[++k]; // check if the user needs data from this node
if (sched->callback_eval(t, true, sched->callback_eval_user_data)) {
break;
}
t = split->graph.nodes[++j1];
} }
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, k + 1); struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
ggml_backend_graph_compute(split_backend, &gv); ggml_backend_graph_compute(split_backend, &gv);
// TODO: k is node index in the split, not in the original graph if (sched->callback_eval(t, true, sched->callback_eval_user_data) && // ask
// TODO: avoid the ask == true call here !sched->callback_eval(t, false, sched->callback_eval_user_data)) { // eval
if (sched->callback_eval(k, t, true, sched->callback_eval_user_data) &&
!sched->callback_eval(k, t, false, sched->callback_eval_user_data)) {
break; break;
} }
j = k; j0 = j1;
} }
} }
uint64_t compute_end_us = ggml_time_us(); uint64_t compute_end_us = ggml_time_us();

View File

@ -154,8 +154,7 @@ extern "C" {
// when ask == false, the scheduler is passing the node tensor to the user for observation // when ask == false, the scheduler is passing the node tensor to the user for observation
// if the user returns false, the scheduler will cancel the graph compute // if the user returns false, the scheduler will cancel the graph compute
// //
// TODO: propose to rename to ggml_backend_sched_callback_eval typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler // Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size); GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
@ -195,7 +194,6 @@ extern "C" {
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
// TODO: propose to rename this to ggml_backend_callback_compare
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends // Compare the output of two backends