test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974)

This commit is contained in:
Johannes Gäßler 2024-09-30 09:55:23 +02:00 committed by Georgi Gerganov
parent cb00020504
commit e98c1c188e
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
3 changed files with 11 additions and 5 deletions

View File

@@ -2052,6 +2052,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_opt_step_adamw( GGML_API struct ggml_tensor * ggml_opt_step_adamw(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * grad,
float alpha, float alpha,
float beta1, float beta1,
float beta2, float beta2,

View File

@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_tensor * ggml_opt_step_adamw( struct ggml_tensor * ggml_opt_step_adamw(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * grad,
float alpha, float alpha,
float beta1, float beta1,
float beta2, float beta2,
float eps, float eps,
float wd) { float wd) {
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
GGML_ASSERT(ggml_are_same_shape(a, grad));
GGML_ASSERT(alpha > 0.0f); GGML_ASSERT(alpha > 0.0f);
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
result->op = GGML_OP_OPT_STEP_ADAMW; result->op = GGML_OP_OPT_STEP_ADAMW;
result->src[0] = a; result->src[0] = a;
result->src[1] = a->grad; result->src[1] = grad;
result->src[2] = ggml_dup_tensor(ctx, a); result->src[2] = ggml_dup_tensor(ctx, grad);
result->src[3] = ggml_dup_tensor(ctx, a); result->src[3] = ggml_dup_tensor(ctx, grad);
return result; return result;
} }
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
if (node->flags & GGML_TENSOR_FLAG_PARAM) { if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd); struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
ggml_build_forward_expand(gb, opt_step); ggml_build_forward_expand(gb, opt_step);
} }
} }

View File

@@ -2751,7 +2751,10 @@ struct test_opt_step_adamw : public test_case {
ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not. ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
ggml_set_name(a, "a"); ggml_set_name(a, "a");
ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd); ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
ggml_set_name(out, "out"); ggml_set_name(out, "out");
return out; return out;