diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a8a74bee1..ce3d92cb2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2052,6 +2052,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aac4e3a7b..bcbc32d91 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, float eps, float wd) { GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); GGML_ASSERT(alpha > 0.0f); GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); @@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw( result->op = GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; - result->src[1] = a->grad; - result->src[2] = ggml_dup_tensor(ctx, a); - result->src[3] = ggml_dup_tensor(ctx, a); + result->src[1] = grad; + result->src[2] = ggml_dup_tensor(ctx, grad); + result->src[3] = ggml_dup_tensor(ctx, grad); return result; } @@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw( if (node->flags & GGML_TENSOR_FLAG_PARAM) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); ggml_build_forward_expand(gb, opt_step); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5c78b6704..95d983aa0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2751,7 +2751,10 @@ struct test_opt_step_adamw : public test_case { ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not. ggml_set_name(a, "a"); - ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd); + ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad, "grad"); + + ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd); ggml_set_name(out, "out"); return out;