From ebf95c22255f5379fe0bd2935795b04eefad5909 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 10 Jun 2024 15:46:54 +0300 Subject: [PATCH] tests : add non-cont unary tests --- tests/test-backend-ops.cpp | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce406a8af..2b48e623e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -642,20 +642,29 @@ struct test_case { struct test_unary : public test_case { const ggml_unary_op op; const ggml_type type; - const std::array ne; + const std::array ne_a; + int v; // view (1 : non-contiguous a) std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR3(type, ne_a, v); } test_unary(ggml_unary_op op, ggml_type type = GGML_TYPE_F32, - std::array ne = {128, 10, 10, 10}) - : op(op), type(type), ne(ne) {} + std::array ne_a = {128, 10, 10, 10}, + int v = 0) + : op(op), type(type), ne_a(ne_a), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * out = ggml_unary(ctx, in, op); + ggml_tensor * a; + if (v & 1) { + auto ne = ne_a; ne[0] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + } + ggml_tensor * out = ggml_unary(ctx, a, op); return out; } @@ -2016,9 +2025,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op }; // unary ops - for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { - test_cases.emplace_back(new test_unary((ggml_unary_op) op)); - test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 })); + for (int v : {0, 1}) { + for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { + test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v)); + test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v)); + } } test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));