From 202084d31d4247764fc6d6d40d2e2bda0c89a73a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 3 Sep 2024 17:21:46 +0200
Subject: [PATCH] tests: add gradient tests for all backends (ggml/932)

* tests: add gradient checking to test-backend-ops

* remove old comment

* reorder includes

* adjust SIN/COS parameters

* add documentation, use supports_op if possible
---
 ggml/include/ggml.h                      |   12 +-
 ggml/src/ggml-backend.c                  |    4 +
 ggml/src/ggml-cuda.cu                    |   12 +
 ggml/src/ggml-cuda/cross-entropy-loss.cu |    4 +-
 ggml/src/ggml-cuda/sum.cu                |   41 +
 ggml/src/ggml-cuda/sum.cuh               |    5 +
 ggml/src/ggml-cuda/unary.cu              |   29 +
 ggml/src/ggml-cuda/unary.cuh             |    3 +
 ggml/src/ggml.c                          |   32 +-
 tests/test-backend-ops.cpp               | 1030 ++++++++++++++++++++--
 10 files changed, 1080 insertions(+), 92 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/sum.cu
 create mode 100644 ggml/src/ggml-cuda/sum.cuh

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 6354fcf51..536018b66 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1272,7 +1272,7 @@ extern "C" {
             size_t                nb1,
             size_t                nb2,
             size_t                nb3,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,nb3), return view(a)
     GGML_API struct ggml_tensor * ggml_set_inplace(
@@ -1282,19 +1282,19 @@ extern "C" {
             size_t                nb1,
             size_t                nb2,
             size_t                nb3,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     GGML_API struct ggml_tensor * ggml_set_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     GGML_API struct ggml_tensor * ggml_set_1d_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,nb3), return modified a
     GGML_API struct ggml_tensor * ggml_set_2d(
@@ -1302,7 +1302,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             size_t                nb1,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,nb3), return view(a)
     GGML_API struct ggml_tensor * ggml_set_2d_inplace(
@@ -1310,7 +1310,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             size_t                nb1,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
index 96b93d2d8..b5d9301a7 100644
--- a/ggml/src/ggml-backend.c
+++ b/ggml/src/ggml-backend.c
@@ -827,6 +827,10 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+        case GGML_OP_ROPE_BACK:
+            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_IM2COL_BACK:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         default:
             return true;
     }
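For context, the two new cases above feed the same ggml_backend_supports_op() entry point that the test harness added later in this patch queries before attempting a gradient check. A rough sketch of that usage pattern, assuming the public ggml_cgraph fields of this ggml revision (illustrative only, not part of the patch):

```c
// Sketch: return true only if `backend` can run every node of a graph.
// ggml_backend_supports_op() is the API that the backend checks above implement.
static bool graph_fully_supported(ggml_backend_t backend, const struct ggml_cgraph * gf) {
    for (int i = 0; i < gf->n_nodes; ++i) {
        if (!ggml_backend_supports_op(backend, gf->nodes[i])) {
            return false; // this node would need a fallback backend
        }
    }
    return true;
}
```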
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index af1bd0518..982316f56 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -27,6 +27,7 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
@@ -2180,6 +2181,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
+        case GGML_OP_ADD1: // TODO: more efficient implementation
            ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_SUB:
@@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_NEG:
+                    ggml_cuda_op_neg(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
@@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_POOL_2D:
            ggml_cuda_op_pool2d(ctx, dst);
            break;
+        case GGML_OP_SUM:
+            ggml_cuda_op_sum(ctx, dst);
+            break;
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
@@ -2748,6 +2756,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2877,6 +2886,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -2896,7 +2906,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
+            return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu
index a14043e70..5575a90f6 100644
--- a/ggml/src/ggml-cuda/cross-entropy-loss.cu
+++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu
@@ -1,6 +1,6 @@
 #include "common.cuh"
 #include "cross-entropy-loss.cuh"
-#include "sumrows.cuh"
+#include "sum.cuh"
 
 #include <cmath>
 #include <cstdint>
@@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
 
     // Combine results from individual blocks:
-    sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu
new file mode 100644
index 000000000..0d5e953ee
--- /dev/null
+++ b/ggml/src/ggml-cuda/sum.cu
@@ -0,0 +1,41 @@
+#include "sumrows.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    size_t tmp_size = 0;
+    DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+    DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+#else
+    // Use (inefficient) sum_rows implementation as a fallback.
+    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
+    sum_rows_f32_cuda(x, dst, ne, 1, stream);
+    GGML_UNUSED(pool);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+}
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t ne = ggml_nelements(src0);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
+}
diff --git a/ggml/src/ggml-cuda/sum.cuh b/ggml/src/ggml-cuda/sum.cuh
new file mode 100644
index 000000000..8cadc3736
--- /dev/null
+++ b/ggml/src/ggml-cuda/sum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 89abfc21d..8ac669f94 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -1,5 +1,15 @@
 #include "unary.cuh"
 
+static __global__ void neg_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = -x[i];
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
     dst[i] = cosf(x[i]);
 }
 
+static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index c610e996a..ed2ffc461 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -1,5 +1,6 @@
 #include "common.cuh"
 
+#define CUDA_NEG_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
@@ -12,6 +13,8 @@
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
 
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
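The two DeviceReduce::Sum calls in sum.cu follow CUB's usual two-call protocol: the first call, with a null scratch pointer, only writes the required scratch size; the second call performs the reduction. A self-contained sketch of the same pattern, with plain CUDA allocation in place of ggml's pool allocator (illustrative only; assumes CUDA >= 11.2 for cudaMallocAsync):

```cpp
#include <cub/cub.cuh>

// Sum n floats on the device into *d_out, asynchronously on `stream`.
static void sum_f32_cub(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    void * d_temp     = nullptr;
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream); // size query only
    cudaMallocAsync(&d_temp, temp_bytes, stream);
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream); // actual reduction
    cudaFreeAsync(d_temp, stream);
}
```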
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5f106d52a..28ee46e04 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -5267,6 +5267,7 @@ struct ggml_tensor * ggml_concat(
     bool is_node = false;
 
     if (a->grad || b->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
@@ -5388,6 +5389,7 @@ struct ggml_tensor * ggml_leaky_relu(
     bool is_node = false;
 
     if (!inplace && (a->grad)) {
+        GGML_ABORT("fatal error"); // TODO: not implemented
        is_node = true;
     }
 
@@ -5826,6 +5828,7 @@ static struct ggml_tensor * ggml_set_impl(
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    GGML_ASSERT(offset < (size_t)(1 << 30));
     int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
     ggml_set_op_params(result, params, sizeof(params));
 
@@ -6783,14 +6786,12 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
-    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
-
-    GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
     bool is_node = false;
 
     if (a->grad) {
-        is_node = false; // TODO: implement backward
+        GGML_ASSERT(false && "backwards pass not implemented");
+        is_node = false;
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6808,6 +6809,7 @@ struct ggml_tensor * ggml_rope_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -7361,6 +7363,11 @@ struct ggml_tensor * ggml_argsort(
         enum ggml_sort_order  order) {
     bool is_node = false;
 
+    if (a->grad) {
+        GGML_ABORT("fatal error"); // TODO: not implemented
+        is_node = true;
+    }
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -10953,9 +10960,6 @@ static void ggml_compute_forward_sum_f32(
         return;
     }
 
-    assert(ggml_is_scalar(dst));
-
-
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
@@ -18356,14 +18360,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad || src1->grad) {
                     GGML_ASSERT(src0->type == tensor->type);
                     GGML_ASSERT(tensor->grad->type == tensor->type);
-                    GGML_ASSERT(tensor->grad->type == src1->grad->type);
+                    GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
 
                     tensor_grad_view = ggml_view_4d(ctx,
-                        tensor->grad,
-                        src1->grad->ne[0],
-                        src1->grad->ne[1],
-                        src1->grad->ne[2],
-                        src1->grad->ne[3],
+                        tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                         nb1, nb2, nb3, offset);
                 }
 
@@ -18432,9 +18432,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
                 memcpy(&offset, tensor->op_params, sizeof(offset));
 
-                size_t nb1 = tensor->nb[1];
-                size_t nb2 = tensor->nb[2];
-                size_t nb3 = tensor->nb[3];
+                size_t nb1     = tensor->nb[1];
+                size_t nb2     = tensor->nb[2];
+                size_t nb3     = tensor->nb[3];
 
                 if (src0->type != src0->grad->type) {
                     // gradient is typically F32, but src0 could be other type
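The "// in bytes" annotations added to ggml.h and the new assert in ggml_set_impl() guard the same pitfall: ggml_set offsets are byte offsets, and they must fit into the int32_t op_params slot, hence the offset < 1 << 30 bound. A minimal sketch of what a correct byte offset looks like (shapes chosen arbitrarily for illustration; assumes an existing ggml_context * ctx):

```c
// Overwrite row 3 of `a` with `b`: the offset is in bytes,
// so it is row_index * a->nb[1] (the row stride in bytes),
// not row_index * a->ne[0] (an element count).
struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
struct ggml_tensor * r = ggml_set_1d(ctx, a, b, 3*a->nb[1]);
```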
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c3536054a..8f659c148 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1,3 +1,20 @@
+// This file defines tests for various GGML ops and backends.
+// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
+// For the backwards pass it asserts that the gradients from backpropagation are consistent
+// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
+// It is also possible to check the performance ("perf" mode).
+//
+// This file has three sections: section 1 does general setup, section 2 defines the GGML ops to be tested,
+// and section 3 defines which tests to run.
+// Quick start for adding a new GGML op: go to section 2 and create a struct that inherits from test_case,
+// then go to section 3 and add an instantiation of your struct.
+
+
+// ##############################
+// ## Section 1: General Setup ##
+// ##############################
+
+
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
@@ -5,6 +22,7 @@
 #include <algorithm>
 #include <array>
 #include <cfloat>
+#include <cmath>
 #include <cstring>
 #include <functional>
 #include <memory>
@@ -212,6 +230,39 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }
 
+// mean absolute asymmetry between a and b
+// asymmetry: (a - b) / (a + b)
+// This is more stable than relative error if one of the values fluctuates towards zero.
+// n: number of values to compare.
+// expected_vals: optional vector of expected values for a. If expected_vals is not empty, filter out all comparisons where
+//     a does not match any of the expected values. Needed for noncontinuous gradients where the numerical calculation can fail.
+static double mean_abs_asymm(const float * a, const float * b, const size_t n, const std::vector<float> & expected_vals) {
+    double sum = 0.0;
+
+    size_t nvalid = 0;
+    for (size_t i = 0; i < n; i++) {
+        if (!expected_vals.empty()) {
+            bool matches_any = false;
+            for (const float & ev : expected_vals) {
+                if (fabsf(a[i] - ev) < 1e-3f) {
+                    matches_any = true;
+                    break;
+                }
+            }
+            if (!matches_any) {
+                continue;
+            }
+        }
+
+        const float asymm = (a[i] - b[i]) / (a[i] + b[i]);
+
+        sum += fabsf(asymm);
+        nvalid++;
+    }
+
+    return sum/nvalid;
+}
+
 // utils for printing the variables of the test cases
 
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
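Written out, the metric that mean_abs_asymm() computes is

$$\mathrm{MAA} = \frac{1}{|V|} \sum_{i \in V} \left|\frac{a_i - b_i}{a_i + b_i}\right|,$$

where $V$ is the set of indices that pass the expected_vals filter (all of $0, \dots, n-1$ when expected_vals is empty).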
@@ -295,6 +346,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
 enum test_mode {
     MODE_TEST,
     MODE_PERF,
+    MODE_GRAD,
 };
 
 struct test_case {
@@ -314,6 +366,32 @@ struct test_case {
         return 1e-7;
     }
 
+    virtual double max_maa_err() {
+        return 1e-4;
+    }
+
+    virtual float grad_eps() {
+        return 1e-1f;
+    }
+
+    // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
+    // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
+    virtual bool grad_precise() {
+        return false;
+    }
+
+    // Skip gradient checks if total number of gradients to be checked is larger than this (to speed up the tests).
+    virtual int64_t grad_nmax() {
+        return 10000;
+    }
+
+    // No effect if empty.
+    // If not empty, skip all gradient checks where the numerical result does not match any of the values.
+    // Needed for dealing with noncontinuous gradients (e.g. ReLU) where estimation using finite differences is unreliable.
+    virtual std::vector<float> grad_expect() {
+        return {};
+    }
+
     virtual void initialize_tensors(ggml_context * ctx) {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
             init_tensor_uniform(t);
@@ -332,6 +410,7 @@ struct test_case {
     }
 
     ggml_cgraph * gf = nullptr;
+    ggml_cgraph * gb = nullptr;
 
     static const int sentinel_size = 1024;
@@ -340,7 +419,7 @@ struct test_case {
     std::vector<ggml_tensor *> sentinels;
 
     void add_sentinel(ggml_context * ctx) {
-        if (mode == MODE_PERF) {
+        if (mode == MODE_PERF || mode == MODE_GRAD) {
             return;
         }
 
         ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
@@ -389,6 +468,7 @@ struct test_case {
             /* .no_alloc = */ true,
         };
         ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
 
         gf = ggml_new_graph(ctx);
@@ -550,6 +630,7 @@ struct test_case {
             /* .no_alloc = */ true,
         };
         ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
 
         ggml_tensor * out = build_graph(ctx);
@@ -643,8 +724,282 @@ struct test_case {
 
         return true;
     }
+
+    bool eval_grad(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_GRAD;
+        const std::vector<float> expect = grad_expect();
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + 2*ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, true),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
+
+        gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+        gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            ggml_free(ctx);
+            return true;
+        }
+
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
+        fflush(stdout);
+
+        if (out->grad == nullptr) {
+            printf("backwards pass not supported\n");
+            ggml_free(ctx);
+            return true;
+        }
+        if (out->type != GGML_TYPE_F32) {
+            ggml_free(ctx);
+            printf("not supported [%s->type != FP32]\n", out->name);
+            return true;
+        }
+
+        // check if the backend supports the ops
+        bool supported = true;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
+                printf("not supported [%s->type != FP32] ", t->name);
+                supported = false;
+                break;
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        int64_t ngrads = 0;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->flags & GGML_TENSOR_FLAG_PARAM) {
+                ngrads += ggml_nelements(t);
+            }
+        }
+        if (ngrads > grad_nmax()) {
+            printf("skipping large tensors for speed\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        if (!ggml_is_scalar(out)) {
+            out = ggml_sum(ctx, out);
+            ggml_set_name(out, "sum_of_out");
+        }
+
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_cpy(gf, gb);
+        ggml_build_backward_expand(ctx, gf, gb, false);
+        if (expect.size() != 1 || expect[0] != 0.0f) {
+            GGML_ASSERT(gb->n_nodes > gf->n_nodes);
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE);
+            }
+        }
+
+        // TODO: refactor so that this check is only needed once
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
+                printf("not supported [%s->type != FP32] ", t->name);
+                supported = false;
+                break;
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+        if (buf == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
+            ggml_free(ctx);
+            return false;
+        }
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            if (!t->grad) {
+                continue;
+            }
+
+            std::vector<float> tmp(ggml_nelements(t->grad));
+            ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
+        }
+
+        // compute the forward pass, then seed the output gradient with 1 and run the backward pass
+        const float onef = 1.0f;
+        ggml_backend_graph_compute(backend, gf);
+        ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
+        ggml_backend_graph_compute(backend, gb);
+
+        bool ok = true;
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                continue;
+            }
+
+            const char * bn = ggml_backend_name(backend);
+            const int64_t ne = ggml_nelements(t);
+
+            std::vector<float> ga = tensor_to_float(t->grad);
+
+            for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
+                // check for nans
+                if (!std::isfinite(ga[i])) {
+                    printf("[%s] nonfinite gradient at index %zu (%s=%f) ", ggml_op_desc(t), (size_t) i, bn, ga[i]);
+                    ok = false;
+                    break;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            std::vector<float> gn(ne); // gradient numeric
+            GGML_ASSERT(ga.size() == gn.size());
+
+            std::vector<float> x0 = tensor_to_float(t); // original t data
+            GGML_ASSERT(ggml_is_scalar(out));
+            GGML_ASSERT(out->type == GGML_TYPE_F32);
+
+            const float eps = grad_eps();
+            for (int64_t i = 0; i < ne; ++i) {
+                const float xiu  = x0[i] + 1.0f*eps; // x, index i, up
+                const float xiuh = x0[i] + 0.5f*eps; // x, index i, up half
+                const float xidh = x0[i] - 0.5f*eps; // x, index i, down half
+                const float xid  = x0[i] - 1.0f*eps; // x, index i, down
+
+                float fu, fuh, fdh, fd; // output values for xiu, xiuh, xidh, xid
+
+                ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
+                ggml_backend_graph_compute(backend, gf);
+                ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
+
+                ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
+                ggml_backend_graph_compute(backend, gf);
+                ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
+
+                if (grad_precise()) {
+                    ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
+                    ggml_backend_graph_compute(backend, gf);
+                    ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
+
+                    ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
+                    ggml_backend_graph_compute(backend, gf);
+                    ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
+
+                    gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
+                } else {
+                    gn[i] = (fu - fd) / (2.0f*eps);
+                }
+
+                ggml_backend_tensor_set(t, x0.data(), 0, ggml_nbytes(t));
+            }
+
+            const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
+            if (err > max_maa_err()) {
+                printf("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
+                ok = false;
+                break;
+            }
+            if (!ok) {
+                break;
+            }
+        }
+
printf("compare failed "); + } + + ggml_backend_buffer_free(buf); + + ggml_free(ctx); + + if (ok) { + printf("\033[1;32mOK\033[0m\n"); + return true; + } + + printf("\033[1;31mFAIL\033[0m\n"); + return false; + } }; + +// ################################### +// ## Section 2: GGML Op Defintions ## +// ################################### + + +// The following is an example showing the bare minimum for creating a test for a GGML op. + +// GGML_OP_EXAMPLE +struct test_example : public test_case { + // Always define these 2 or variants thereof: + const ggml_type type; // The type of the input tensors. + const std::array ne; // The shape of the input tensors. + // For some ops it's necessary to define multiple types or shapes for the inputs. + // Or they may need additional parameters. + + // Put all parameters needed to fully define the test into one of the VARS_TO_STR macros. + // In most cases these are just the properties of the struct that you defined above. + // This is needed for info prints. + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + // Define a constructor for the struct. + // In most cases it will be sufficient to have the same arguments as the struct has properties + // and just use initializer lists. + test_example(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + // Define how a simple GGML compute graph can be constructed for the new GGML op. + ggml_tensor * build_graph(ggml_context * ctx) override { + // Step 1: create input tensors that don't depend on any other tensors: + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); // Setting names is optional but it's useful for debugging. + + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(b, "b"); + + // Step 2: use the op that you want to test in the GGML compute graph. + ggml_tensor * out = ggml_add(ctx, a, b); // For this example we're just doing a simple addition. + ggml_set_name(out, "out"); + + // Step 3: return the output tensor. + return out; + } + // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a) + // immediately after you create the tensors. + // This is optional and only makes sense if a backwards pass has actually been implemented for the new op. 
@@ -658,20 +1013,36 @@ struct test_unary : public test_case {
 
     test_unary(ggml_unary_op op,
             ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
             int v = 0)
         : op(op), type(type), ne_a(ne_a), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
+            op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
+
         ggml_tensor * a;
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (grad_supported) {
+                ggml_set_param(ctx, a);
+            }
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (grad_supported) {
+                ggml_set_param(ctx, a);
+            }
+            ggml_set_name(a, "a");
         }
+
         ggml_tensor * out = ggml_unary(ctx, a, op);
+        ggml_set_name(out, "out");
+
         return out;
     }
@@ -681,6 +1052,24 @@ struct test_unary : public test_case {
             init_tensor_uniform(t, -150.f, 150.f);
         }
     }
+
+    float grad_eps() override {
+        return 15.0f;
+    }
+
+    std::vector<float> grad_expect() override {
+        if (op == GGML_UNARY_OP_ABS) {
+            return {-1.0f, 1.0f};
+        }
+        if (op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_STEP) {
+            return {0.0f};
+        }
+        if (op == GGML_UNARY_OP_RELU) {
+            return {0.0f, 1.0f};
+        }
+        return {};
+    }
+
 };
 
 // GGML_OP_GET_ROWS
@@ -701,11 +1090,24 @@ struct test_get_rows : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_set_name(in, "in");
+
         ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        ggml_set_name(rows, "rows");
         if (v) {
             rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+            ggml_set_name(rows, "view_of_rows");
         }
+
+        const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
+        if (grad_supported) {
+            ggml_set_param(ctx, in);
+            // rows is a constant input -> no gradients
+        }
+
         ggml_tensor * out = ggml_get_rows(ctx, in, rows);
+        ggml_set_name(out, "out");
+
         return out;
     }
@@ -741,14 +1143,21 @@ struct test_repeat : public test_case {
     }
 
     test_repeat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             std::array<int64_t, 4> nr = {2, 2, 2, 2})
         : type(type), ne(ne), nr(nr) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(target, "target");
+
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         ggml_tensor * out = ggml_repeat(ctx, src, target);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -774,10 +1183,62 @@ struct test_dup : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         if (_use_permute) {
             src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(src, "src_permuted");
         }
+
         ggml_tensor * out = ggml_dup(ctx, src);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SET
+struct test_set : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array<int64_t, 4> ne;
+    const int dim;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type_src, type_dst, ne, dim);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
+    test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
+        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
+        auto ne_dst = ne;
+        for (int i = 0; i < dim; ++i) {
+            ne_dst[i] *= 2;
+        }
+        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
+        ggml_set_param(ctx, dst);
+        ggml_set_name(dst, "dst");
+
+        size_t offset = 0;
+        for (int i = 0; i < dim; ++i) {
+            offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
+        }
+        ggml_tensor * out = ggml_set(ctx, dst, src,
+            // The backwards pass requires setting a contiguous region:
+            src->nb[1], src->nb[2], src->nb[3], offset);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
@@ -810,11 +1271,20 @@ struct test_cpy : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         if (_src_use_permute) {
             src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(src, "src_permuted");
         }
+
         ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_set_name(dst, "dst");
+
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -834,8 +1304,14 @@ struct test_cont : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         src = ggml_transpose(ctx, src);
+        ggml_set_name(src, "src_transposed");
+
         ggml_tensor * out = ggml_cont(ctx, src);
+        ggml_set_name(out, "out");
 
         return out;
     }
@@ -866,21 +1342,79 @@ struct test_bin_bcast : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(a, "a");
+
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        // The backwards pass supports broadcasting only for GGML_ADD:
+        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        if (grad_supported) {
+            ggml_set_param(ctx, a);
+            ggml_set_param(ctx, b);
+        }
+
         ggml_tensor * out = op(ctx, a, b);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (op == ggml_div) {
-                // avoid division by zero
-                init_tensor_uniform(t, 1.0f, 2.0f);
+            if (op == ggml_mul || op == ggml_div) {
+                // MUL and DIV have numerical issues around zero:
+                init_tensor_uniform(t, 0.9f, 1.1f);
             } else {
                 init_tensor_uniform(t);
             }
         }
     }
+
+    float grad_eps() override {
+        return 0.1f * (op == ggml_mul ? ne[0]*ne[1]*ne[2]*ne[3] : 1);
+    }
+
+    bool grad_precise() override {
+        return op == ggml_div;
+    }
+
+    double max_maa_err() override {
+        return op == ggml_add ? 1e-4 : 1e-3;
+    }
+};
+
+// GGML_OP_ADD1
+struct test_add1 : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_add1(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
+        // ggml_set_param(ctx, b); // TODO: implement
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_add1(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
+};
 
 // GGML_OP_SCALE
@@ -900,7 +1434,12 @@ struct test_scale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_scale(ctx, a, scale);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -916,13 +1455,17 @@ struct test_norm : public test_case {
     }
 
     test_norm(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
             float eps = 1e-6f)
         : type(type), ne(ne), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -938,15 +1481,24 @@ struct test_rms_norm : public test_case {
     }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
             float eps = 1e-6f)
         : type(type), ne(ne), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_SSM_CONV
@@ -1038,7 +1590,14 @@ struct test_mul_mat : public test_case {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
         ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_set_param(ctx, a);
+        ggml_set_param(ctx, b);
+        ggml_set_name(a, "a");
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_mul_mat(ctx, a, b);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1082,12 +1641,21 @@ struct test_mul_mat_id : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+        ggml_set_name(as, "as");
+
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        ggml_set_name(ids, "ids");
         if (n_used != n_mats) {
             ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
         }
+
         ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
+        ggml_set_name(out, "out");
+
         return out;
     }
@@ -1123,14 +1691,23 @@ struct test_sqr : public test_case {
     }
 
     test_sqr(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
        : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sqr(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    float grad_eps() override {
+        return 0.1f * 0.25f*ne[0]*ne[1]*ne[2]*ne[3]; // 10% of expected value of sum.
+    }
 };
 
 // GGML_OP_SQRT
@@ -1143,21 +1720,70 @@ struct test_sqrt : public test_case {
     }
 
     test_sqrt(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sqrt(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         // fill with positive values
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, 0.0f, 100.0f);
+            init_tensor_uniform(t, 50.0f, 100.0f);
         }
     }
+
+    float grad_eps() override {
+        return 20.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_LOG
+struct test_log : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_log(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_log(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
+            init_tensor_uniform(t, 0.9f, 1.1f);
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_SIN
@@ -1170,20 +1796,37 @@ struct test_sin : public test_case {
     }
 
     test_sin(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 2, 2, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sin(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, -100.0f, 100.0f);
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_COS
@@ -1196,20 +1839,37 @@ struct test_cos : public test_case {
     }
 
     test_cos(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 2, 2, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_cos(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, -100.0f, 100.0f);
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_CLAMP
@@ -1224,15 +1884,27 @@ struct test_clamp : public test_case {
     }
 
     test_clamp(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             float min = -0.5f, float max = 0.5f)
         : type(type), ne(ne), min(min), max(max) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_clamp(ctx, a, min, max);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    float grad_eps() override {
+        return 1e-2f;
+    }
+
+    std::vector<float> grad_expect() override {
+        return {0.0f, 1.0f};
+    }
 };
 
 // GGML_OP_DIAG_MASK_INF
@@ -1246,13 +1918,18 @@ struct test_diag_mask_inf : public test_case {
     }
 
     test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 10, 3, 2},
             int n_past = 5)
         : type(type), ne(ne), n_past(n_past) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1276,7 +1953,7 @@ struct test_soft_max : public test_case {
     }
 
     test_soft_max(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
             float scale = 1.0f,
             float max_bias = 0.0f)
@@ -1284,13 +1961,24 @@ struct test_soft_max : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * mask = nullptr;
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
+            ggml_set_name(mask, "mask");
         }
+
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
@@ -1312,7 +2000,7 @@ struct test_rope : public test_case {
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
             int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
         : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
 
@@ -1321,13 +2009,29 @@ struct test_rope : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(ctx, a);
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(ctx, a);
+            ggml_set_name(a, "a");
         }
+
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
-        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_set_name(pos, "pos");
+
+        ggml_tensor * freq = nullptr;
+        if (ff) {
+            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
+            ggml_set_name(freq, "freq");
+        }
+
         ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_set_name(out, "out");
+
         return out;
     }
@@ -1350,6 +2054,14 @@ struct test_rope : public test_case {
             }
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_POOL2D
@@ -1381,7 +2093,12 @@ struct test_pool2d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(ctx, input);
+        ggml_set_name(input, "input");
+
         ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1406,8 +2123,14 @@ struct test_conv_transpose_1d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
         ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
         ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1446,8 +2169,15 @@ struct test_im2col : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(ctx, input);
+        ggml_set_name(input, "input");
+
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
         ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1465,8 +2195,8 @@ struct test_concat : public test_case {
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
-            int64_t ne_b_d = 10,
+            std::array<int64_t, 4> ne_a = {10, 5, 5, 5},
+            int64_t ne_b_d = 5,
             int dim = 2, int v = 0)
         : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
 
@@ -1477,19 +2207,30 @@ struct test_concat : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
         }
         ggml_tensor * b;
         if (v & 2) {
             auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
             b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b, "b");
+
             b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
         } else {
             b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+            ggml_set_name(b, "b");
         }
+
         ggml_tensor * out = ggml_concat(ctx, a, b, dim);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1511,7 +2252,11 @@ struct test_argsort : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_argsort(ctx, a, order);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -1544,6 +2289,35 @@ struct test_argsort : public test_case {
     }
 };
 
+// GGML_OP_SUM
+struct test_sum : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sum(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);
+    }
+};
+
 // GGML_OP_SUM_ROWS
 struct test_sum_rows : public test_case {
     const ggml_type type;
@@ -1554,12 +2328,17 @@ struct test_sum_rows : public test_case {
     }
 
     test_sum_rows(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sum_rows(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1582,8 +2361,16 @@ struct test_upscale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        if (transpose) a = ggml_transpose(ctx, a);
+        ggml_set_name(a, "a");
+
+        if (transpose) {
+            a = ggml_transpose(ctx, a);
+            ggml_set_name(a, "a_transposed");
+        }
+
         ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1605,7 +2392,11 @@ struct test_upscale_ext : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1], ne_tgt[2], ne_tgt[3]);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1629,7 +2420,11 @@ struct test_group_norm : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1645,14 +2440,22 @@ struct test_acc : public test_case {
     }
 
     test_acc(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
-            std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
+            std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
+            std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
         : type(type), ne_a(ne_a), ne_b(ne_b) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_set_param(ctx, b);
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1675,7 +2478,11 @@ struct test_pad : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1697,6 +2504,8 @@ struct test_arange : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * out = ggml_arange(ctx, start, stop, step);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1719,7 +2528,11 @@ struct test_timestep_embedding : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1735,13 +2548,17 @@ struct test_leaky_relu : public test_case {
     }
 
     test_leaky_relu(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne_a = {10, 5, 4, 3},
             float negative_slope = 0.1f)
         : type(type), ne_a(ne_a), negative_slope(negative_slope) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1768,19 +2585,37 @@ struct test_flash_attn_ext : public test_case {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+            bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
 
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_set_name(q, "q");
+
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_set_name(k, "k");
+
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
-        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
+        ggml_set_name(v, "v");
+
+        ggml_tensor * m = nullptr;
+        if (mask) {
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+            ggml_set_name(m, "m");
+        }
+
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_CROSS_ENTROPY_LOSS
@@ -1793,15 +2628,42 @@ struct test_cross_entropy_loss : public test_case {
     }
 
     test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, logits);
+        ggml_set_name(logits, "logits");
+
         ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        // The labels are assumed to be constant -> no gradients.
+        ggml_set_name(labels, "labels");
+
+        // Ensure labels add up to 1:
+        labels = ggml_soft_max(ctx, labels);
+        ggml_set_name(labels, "labels_normalized");
+
         ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients.
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 enum llm_norm_type {
@@ -2188,6 +3050,12 @@ struct test_falcon : public test_llm {
     }
 };
 
+
+// ###########################################
+// ## Section 3: GGML Op Test Instantiation ##
+// ###########################################
+
+
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
     std::default_random_engine rng(0);
@@ -2230,8 +3098,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     // unary ops
     for (int v : {0, 1}) {
         for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, {   7, 13, 19, 23 }, v));
+            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128,  2,  2,  2 }, v));
+            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, {   5,  7, 11, 13 }, v));
         }
     }
 
@@ -2267,11 +3135,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
     // test cases for 1D im2col
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
 
     // sycl backend will limit task global_range < MAX_INT
     // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
@@ -2290,13 +3160,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 10, 10, 10}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 10, 10, 10}, {1, 1, 1, 2})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2})); test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); @@ -2309,6 +3179,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); + for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim)); + } + for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (ggml_type type_dst : all_types) { test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4})); @@ -2341,16 +3215,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}); - add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2}); // stable diffusion add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1}); @@ -2369,11 +3243,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1}); //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1}); + test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) { - test_cases.emplace_back(new 
test_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); - test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); + test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); @@ -2477,13 +3352,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_sqr()); test_cases.emplace_back(new test_sqrt()); + test_cases.emplace_back(new test_log()); test_cases.emplace_back(new test_sin()); test_cases.emplace_back(new test_cos()); test_cases.emplace_back(new test_clamp()); - test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); - test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5)); - test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5)); + test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); + test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5)); + test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5)); #if 0 std::uniform_int_distribution<> dist_ne1(1, 50); @@ -2527,23 +3403,23 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (bool ff : {false, true}) { // freq_factors - test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B + test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B } if (all) { - test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 
80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2) } - test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B) } } @@ -2567,6 +3443,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen } + test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_upscale()); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); @@ -2609,6 +3486,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #endif // run tests + if (mode == MODE_GRAD) { + size_t n_ok = 0; + for (auto & test : test_cases) { + if (test->eval_grad(backend, op_name)) { + n_ok++; + } + } + printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + + return n_ok == test_cases.size(); + } + if (mode == MODE_TEST) { ggml_backend_t backend_cpu = ggml_backend_cpu_init(); @@ -2637,8 +3526,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op static void usage(char ** argv) { printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]); - printf(" valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n"); - printf(" op names are as given by ggml_op_desc()\n"); + printf(" valid modes:\n"); + printf(" - test (default, compare with CPU backend for correctness)\n"); + printf(" - perf (performance evaluation)\n"); + printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); + printf(" op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n"); } int main(int argc, char ** argv) { @@ -2651,6 +3543,8 @@ int main(int argc, char ** argv) { mode = MODE_TEST; } else if (strcmp(argv[i], "perf") == 0) { mode = MODE_PERF; + } else if (strcmp(argv[i], "grad") == 0) { + mode = MODE_GRAD; } else if (strcmp(argv[i], "-o") == 0) { if (i + 1 < argc) { op_name_filter = argv[++i]; @@ -2688,7 +3582,7 @@ int main(int argc, char ** argv) { ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL); GGML_ASSERT(backend != NULL); - if (backend_filter == NULL && ggml_backend_is_cpu(backend)) { + if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) { printf(" Skipping CPU backend\n"); ggml_backend_free(backend); n_ok++;
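For reference, the new "grad" mode validates each analytic gradient from ggml's backward pass against an estimate computed with the method of finite differences. The sketch below shows only the core idea (a central-difference quotient) as stand-alone C++; numerical_grad and the toy loss f are illustrative names for this note, not part of ggml or of the actual test-backend-ops harness:

    #include <cmath>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Central-difference estimate of df/dx_i:
    //   (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    static double numerical_grad(const std::function<double(const std::vector<double> &)> & f,
                                 std::vector<double> x, size_t i, double eps) {
        const double xi = x[i];
        x[i] = xi + eps;
        const double f_plus  = f(x);
        x[i] = xi - eps;
        const double f_minus = f(x);
        return (f_plus - f_minus) / (2.0*eps);
    }

    int main() {
        // Toy scalar "loss" f(x) = sum(x_i^2) with known analytic gradient 2*x_i.
        auto f = [](const std::vector<double> & x) {
            double sum = 0.0;
            for (double v : x) { sum += v*v; }
            return sum;
        };

        const std::vector<double> x = {0.5, -1.25, 3.0};
        const double eps = 1e-4;

        for (size_t i = 0; i < x.size(); ++i) {
            const double g_num = numerical_grad(f, x, i, eps);
            const double g_ana = 2.0*x[i]; // what backpropagation would return here
            printf("i=%zu numerical=%.6f analytic=%.6f abs.err=%.2e\n",
                   i, g_num, g_ana, std::fabs(g_num - g_ana));
        }
        return 0;
    }

The step size is the trade-off that the per-test overrides above exist for: too small an eps and the difference quotient drowns in floating-point rounding, too large and the truncation error of the quotient dominates. That is why test_cross_entropy_loss raises grad_eps() to 1.0 and initializes the logits in [-100, 100] (softmax is closer to linear for large differences between logits), and why tests such as test_flash_attn_ext opt into grad_precise() to request a more precise, but slower, difference estimate.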