mirror of https://github.com/ggerganov/llama.cpp.git
backend : add eval callback
ggml-ci
This commit is contained in:
parent 4483396751, commit 65648b341f
@@ -6,11 +6,36 @@
 #include <string>
 #include <vector>
 
+// a function that can be called for every computed node during graph evaluation
+// the user can choose whether to observe the data of the node depending on the tensor parameters
+static bool observe_compute(int node_index, struct ggml_tensor * t, void * user_data) {
+    GGML_UNUSED(user_data);
+
+    // check if name contains soft_max
+    if (strstr(t->name, "soft_max") != 0) {
+        printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
+                __func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
+
+        std::vector<float> t_data(ggml_nelements(t));
+        ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
+
+        // print first row
+        for (int i = 0; i < t->ne[0]; i++) {
+            printf("%8.4f ", t_data[i]);
+        }
+        printf("\n");
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
+    bool observe = false;
+
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [PROMPT]\n", argv[0]);
+        printf("usage: %s MODEL_PATH [PROMPT] [OBSERV]\n", argv[0]);
         return 1;
     }
 
@@ -22,6 +47,10 @@ int main(int argc, char ** argv) {
         params.prompt = argv[2];
     }
 
+    if (argc >= 4) {
+        observe = atoi(argv[3]);
+    }
+
     if (params.prompt.empty()) {
         params.prompt = "Hello my name is";
     }
 
@@ -37,7 +66,7 @@ int main(int argc, char ** argv) {
 
     llama_model_params model_params = llama_model_default_params();
 
-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+    model_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
@@ -55,6 +84,9 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    ctx_params.cb_eval = observe ? observe_compute : NULL;
+    ctx_params.cb_eval_user_data = NULL;
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
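The example above wires the callback through llama_context_params. For reference, a minimal sketch of the same wiring in isolation; the helper name run_with_observer and the error handling are illustrative, not part of this commit:

// minimal sketch (assumes the observe_compute callback from the example above)
static int run_with_observer(const char * model_path) {
    llama_model_params model_params = llama_model_default_params();
    llama_model * model = llama_load_model_from_file(model_path, model_params);
    if (model == NULL) {
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.cb_eval           = observe_compute; // invoked for every computed node
    ctx_params.cb_eval_user_data = NULL;            // passed back to the callback unchanged

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... build a batch and call llama_decode() as usual; observe_compute
    //     now fires once per computed graph node during evaluation ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}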
@@ -802,6 +802,9 @@ struct ggml_backend_sched {
     __attribute__((aligned(GGML_MEM_ALIGN)))
 #endif
     char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+
+    ggml_backend_sched_eval_callback callback_eval;
+    void * callback_eval_user_data;
 };
 
 #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@@ -1324,9 +1327,30 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
         ggml_graph_dump_dot(split->graph, NULL, split_filename);
 #endif
 
+
         uint64_t compute_start_us = ggml_time_us();
-        ggml_backend_graph_compute(split_backend, &split->graph);
-        //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        if (!sched->callback_eval) {
+            ggml_backend_graph_compute(split_backend, &split->graph);
+            //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+        } else {
+            // similar to ggml_backend_compare_graph_backend
+            for (int j = 0; j < split->graph.n_nodes; j++) {
+                struct ggml_tensor * t = split->graph.nodes[j];
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j, j + 1);
+
+                ggml_backend_graph_compute(split_backend, &gv);
+
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                // TODO: j is node index in the split, not in the original graph
+                if (!sched->callback_eval(j, t, sched->callback_eval_user_data)) {
+                    break;
+                }
+            }
+        }
         uint64_t compute_end_us = ggml_time_us();
         compute_us[split_backend_id] += compute_end_us - compute_start_us;
     }
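When a callback is set, the split is no longer computed in a single call: each node is wrapped in a one-node ggml_graph_view and computed individually, so the callback sees the tensor right after its op has run and can stop the split early by returning false. A sketch of a callback that relies on this behaviour, assuming the ggml op enum value GGML_OP_SOFT_MAX from ggml.h (the stopping condition is illustrative):

// illustrative: stop evaluating the current split once a soft_max result is available
static bool stop_after_soft_max(int node_index, struct ggml_tensor * t, void * user_data) {
    GGML_UNUSED(node_index);
    GGML_UNUSED(user_data);

    // returning false makes the per-node loop in sched_compute_splits break
    return t->op != GGML_OP_SOFT_MAX;
}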
@@ -1352,6 +1376,10 @@ static void sched_reset(ggml_backend_sched_t sched) {
     memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
     memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
 
+    // TODO: should we clear the callbacks?
+    //sched->callback_eval = NULL;
+    //sched->callback_eval_user_data = NULL;
+
     sched->is_reset = true;
 }
 
@@ -1431,6 +1459,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
     sched_reset(sched);
 }
 
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    sched->callback_eval = callback;
+    sched->callback_eval_user_data = user_data;
+}
+
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
     return sched->n_splits;
 }
 
@@ -148,6 +148,9 @@ extern "C" {
     struct ggml_backend_sched;
     typedef struct ggml_backend_sched * ggml_backend_sched_t;
 
+    // TODO: propose to rename to ggml_backend_sched_callback_eval
+    typedef bool (*ggml_backend_sched_eval_callback)(int node_index, struct ggml_tensor * t, void * user_data);
+
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
@@ -168,6 +171,9 @@ extern "C" {
     // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
     GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
     //
     // Utils
     //
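For code that drives a scheduler directly rather than through llama.cpp, the callback is attached with the new setter before computing the graph. A sketch, assuming a scheduler and a built graph already exist and using the scheduler's graph-compute entry point (ggml_backend_sched_graph_compute) declared elsewhere in this header:

// illustrative logging callback for direct ggml_backend_sched usage
static bool log_node(int node_index, struct ggml_tensor * t, void * user_data) {
    GGML_UNUSED(user_data);
    printf("node %4d: %16s (%s)\n", node_index, t->name, ggml_op_name(t->op));
    return true; // keep computing the remaining nodes
}

// ... given ggml_backend_sched_t sched and struct ggml_cgraph * graph ...
ggml_backend_sched_set_eval_callback(sched, log_node, NULL);
ggml_backend_sched_graph_compute(sched, graph);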
@@ -183,6 +189,7 @@ extern "C" {
     GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
     GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
 
+    // TODO: propose to rename this to ggml_backend_callback_compare
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
 
@@ -1393,6 +1393,9 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+
+    ggml_backend_sched_eval_callback cb_eval;
+    void * cb_eval_user_data;
 };
 
 struct llama_layer {
@@ -6254,6 +6257,7 @@ static int llama_decode_internal(
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
@@ -9267,6 +9271,8 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all =*/ false,
         /*.embedding =*/ false,
         /*.offload_kqv =*/ true,
+        /*.cb_eval =*/ nullptr,
+        /*.cb_eval_user_data =*/ nullptr,
     };
 
     return result;
@@ -9401,6 +9407,9 @@ struct llama_context * llama_new_context_with_model(
                            hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
                                                           hparams.n_ctx_train;
 
+    cparams.cb_eval = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
llama.h
@@ -2,6 +2,7 @@
 #define LLAMA_H
 
 #include "ggml.h"
+#include "ggml-backend.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@@ -239,6 +240,9 @@ extern "C" {
         bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+
+        ggml_backend_sched_eval_callback cb_eval;
+        void * cb_eval_user_data;
     };
 
     // model quantization parameters