mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 20:22:25 +01:00
tests : add non-cont unary tests
This commit is contained in:
parent
bfaa676b08
commit
ebf95c2225
@ -642,20 +642,29 @@ struct test_case {
|
||||
struct test_unary : public test_case {
|
||||
const ggml_unary_op op;
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
int v; // view (1 : non-contiguous a)
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR3(type, ne_a, v);
|
||||
}
|
||||
|
||||
test_unary(ggml_unary_op op,
|
||||
ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {128, 10, 10, 10})
|
||||
: op(op), type(type), ne(ne) {}
|
||||
std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
|
||||
int v = 0)
|
||||
: op(op), type(type), ne_a(ne_a), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_unary(ctx, in, op);
|
||||
ggml_tensor * a;
|
||||
if (v & 1) {
|
||||
auto ne = ne_a; ne[0] *= 3;
|
||||
a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
} else {
|
||||
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
}
|
||||
ggml_tensor * out = ggml_unary(ctx, a, op);
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -2016,9 +2025,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
};
|
||||
|
||||
// unary ops
|
||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }));
|
||||
for (int v : {0, 1}) {
|
||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
|
||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
|
||||
|
Loading…
Reference in New Issue
Block a user