diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index e7d2ad592..65bb238a0 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) { int n_past = 0; - ggml_cgraph gf = {}; + struct ggml_cgraph * gf = NULL; + gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true); get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); - struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch); + struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch); // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits); struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); - ggml_build_forward_expand(&gf, e); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, e); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); float error_before_opt = ggml_get_f32_1d(e, 0); @@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) { opt_params_lbfgs.lbfgs.n_iter = 16; ggml_opt(ctx0, opt_params_lbfgs, e); // - ggml_build_forward_expand(&gf, e); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, e); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); float error_after_opt = ggml_get_f32_1d(e, 0); @@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) { }; struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph gf = {}; + struct ggml_cgraph * gf = NULL; + gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true); int n_past = 0; - struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); + struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past); - ggml_build_forward_expand(&gf, logits); - ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); + ggml_build_forward_expand(gf, logits); + ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1); struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);