diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c18ff07ea..0ce498e9e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case { } }; +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs)); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1)); + test_cases.emplace_back(new test_attn(64, 32, 512, 8)); + test_cases.emplace_back(new test_attn(64, 32, 512, 7)); + test_cases.emplace_back(new test_attn(64, 32, 512, 1)); + test_cases.emplace_back(new test_attn(80, 32, 512, 8)); + test_cases.emplace_back(new test_attn(80, 32, 512, 7)); + test_cases.emplace_back(new test_attn(80, 32, 512, 1)); + test_cases.emplace_back(new test_attn(128, 32, 512, 8)); + test_cases.emplace_back(new test_attn(128, 32, 512, 7)); + test_cases.emplace_back(new test_attn(128, 32, 512, 1)); + + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); + test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer