baby-llama : allocate graphs in ggml_context (#5573)

* Fixed the baby-llama issue (see issue #4830)

* minor : fix whitespaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
NawafAlansari 2024-02-19 03:25:38 -05:00 committed by GitHub
parent 11b12de39b
commit 4480542b22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) {
int n_past = 0; int n_past = 0;
ggml_cgraph gf = {}; struct ggml_cgraph * gf = NULL;
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch); struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch);
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits); // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
ggml_build_forward_expand(&gf, e); ggml_build_forward_expand(gf, e);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
float error_before_opt = ggml_get_f32_1d(e, 0); float error_before_opt = ggml_get_f32_1d(e, 0);
@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) {
opt_params_lbfgs.lbfgs.n_iter = 16; opt_params_lbfgs.lbfgs.n_iter = 16;
ggml_opt(ctx0, opt_params_lbfgs, e); ggml_opt(ctx0, opt_params_lbfgs, e);
// //
ggml_build_forward_expand(&gf, e); ggml_build_forward_expand(gf, e);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
float error_after_opt = ggml_get_f32_1d(e, 0); float error_after_opt = ggml_get_f32_1d(e, 0);
@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) {
}; };
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph gf = {}; struct ggml_cgraph * gf = NULL;
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
int n_past = 0; int n_past = 0;
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits); ggml_build_forward_expand(gf, logits);
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);