mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
tests : add ATTN tests
This commit is contained in:
parent
1db22d7032
commit
4794821a31
@ -1418,6 +1418,48 @@ struct test_flash_attn_ext : public test_case {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Attention
|
||||||
|
struct test_attn : public test_case {
|
||||||
|
const int64_t hs; // head size
|
||||||
|
const int64_t nh; // num heads
|
||||||
|
const int64_t kv; // kv size
|
||||||
|
const int64_t nb; // batch size
|
||||||
|
|
||||||
|
std::string op_desc(ggml_tensor * t) override {
|
||||||
|
return "ATTN";
|
||||||
|
|
||||||
|
GGML_UNUSED(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(hs, nh, kv, nb);
|
||||||
|
}
|
||||||
|
|
||||||
|
double max_nmse_err() override {
|
||||||
|
return 5e-4;
|
||||||
|
}
|
||||||
|
|
||||||
|
test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
|
||||||
|
: hs(hs), nh(nh), kv(kv), nb(nb) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
||||||
|
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||||
|
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
|
||||||
|
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
|
||||||
|
cur = ggml_mul_mat (ctx, k, q);
|
||||||
|
cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs));
|
||||||
|
cur = ggml_mul_mat (ctx, v, cur);
|
||||||
|
cur = ggml_permute (ctx, cur, 0, 2, 1, 3);
|
||||||
|
cur = ggml_cont_2d (ctx, cur, hs*nh, nb);
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Mixtral MOE
|
// Mixtral MOE
|
||||||
struct test_moe : public test_case {
|
struct test_moe : public test_case {
|
||||||
const int n_experts;
|
const int n_experts;
|
||||||
@ -1684,15 +1726,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_pad());
|
test_cases.emplace_back(new test_pad());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 8));
|
test_cases.emplace_back(new test_attn(64, 32, 512, 8));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 7));
|
test_cases.emplace_back(new test_attn(64, 32, 512, 7));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 256*8, 1));
|
test_cases.emplace_back(new test_attn(64, 32, 512, 1));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 8));
|
test_cases.emplace_back(new test_attn(80, 32, 512, 8));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 7));
|
test_cases.emplace_back(new test_attn(80, 32, 512, 7));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 256*8, 1));
|
test_cases.emplace_back(new test_attn(80, 32, 512, 1));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 8));
|
test_cases.emplace_back(new test_attn(128, 32, 512, 8));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 7));
|
test_cases.emplace_back(new test_attn(128, 32, 512, 7));
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256*8, 1));
|
test_cases.emplace_back(new test_attn(128, 32, 512, 1));
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7));
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1));
|
||||||
|
|
||||||
#if !defined(__SANITIZE_THREAD__)
|
#if !defined(__SANITIZE_THREAD__)
|
||||||
// FIXME: these tests use too much memory with thread sanitizer
|
// FIXME: these tests use too much memory with thread sanitizer
|
||||||
|
Loading…
Reference in New Issue
Block a user