mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 04:23:06 +01:00
CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
This commit is contained in:
parent
1af6945eb0
commit
8137b4bb2b
@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
|
@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
case GGML_OP_OUT_PROD:
|
||||
return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
|
||||
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
|
||||
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
||||
|
||||
template <typename T>
|
||||
static __global__ void k_repeat_back(
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
|
||||
|
||||
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
|
||||
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
|
||||
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
|
||||
const int64_t tid2 = tid23 % ne2;
|
||||
const int64_t tid3 = tid23 / ne2;
|
||||
|
||||
if (tid0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
|
||||
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
@ -274,12 +279,14 @@ struct bin_bcast_cuda {
|
||||
|
||||
template <typename T>
|
||||
static void repeat_back_cuda(
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
|
||||
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
template<class op>
|
||||
@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_can_repeat(dst, src0));
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
GGML_ASSERT(ne2*ne3 <= (1 << 15));
|
||||
|
||||
const size_t ts = ggml_type_size(src0->type);
|
||||
const size_t s00 = nb00 / ts;
|
||||
const size_t s01 = nb01 / ts;
|
||||
const size_t s02 = nb02 / ts;
|
||||
const size_t s03 = nb03 / ts;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false);
|
||||
|
@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
|
||||
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
|
@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(handle, stream));
|
||||
|
||||
const int64_t lda = nb01 / sizeof(float);
|
||||
const int64_t ldc = nb1 / sizeof(float);
|
||||
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5339,7 +5339,7 @@ static void ggml_compute_backward(
|
||||
} break;
|
||||
case GGML_OP_MUL: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
|
||||
@ -5431,21 +5431,25 @@ static void ggml_compute_backward(
|
||||
// src1.shape [n,p,qq,rr]
|
||||
|
||||
if (src0_needs_grads) {
|
||||
struct ggml_tensor * s1_tg =
|
||||
GGML_ASSERT(grad->ne[2] == src1->ne[2]);
|
||||
GGML_ASSERT(grad->ne[3] == src1->ne[3]);
|
||||
struct ggml_tensor * tmp =
|
||||
ggml_out_prod(ctx, // [n,m,qq,rr]
|
||||
src1, // [n,p,qq,rr]
|
||||
grad); // [m,p,qq,rr]
|
||||
const int64_t qq = s1_tg->ne[2];
|
||||
const int64_t rr = s1_tg->ne[3];
|
||||
const int64_t q1 = src0->ne[2];
|
||||
const int64_t r1 = src0->ne[3];
|
||||
const bool ne2_broadcasted = qq > q1;
|
||||
const bool ne3_broadcasted = rr > r1;
|
||||
if (ne2_broadcasted || ne3_broadcasted) {
|
||||
// sum broadcast repetitions of s1_tg into shape of src0
|
||||
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
|
||||
if (!ggml_are_same_shape(tmp, src0)) {
|
||||
GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
|
||||
GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
|
||||
GGML_ASSERT(tmp->ne[3] == 1);
|
||||
|
||||
const int64_t nr2 = tmp->ne[2] / src0->ne[2];
|
||||
const size_t nb2 = tmp->nb[2] * nr2;
|
||||
const size_t nb3 = tmp->nb[2];
|
||||
|
||||
tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
|
||||
tmp = ggml_repeat_back(ctx, tmp, src0);
|
||||
}
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, tmp);
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc1,
|
||||
@ -5514,7 +5518,9 @@ static void ggml_compute_backward(
|
||||
if (src0_needs_grads) {
|
||||
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, grad);
|
||||
GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0,
|
||||
ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RESHAPE: {
|
||||
|
@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_REPEAT_BACK
|
||||
struct test_repeat_back : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const std::array<int, 4> nr;
|
||||
const bool v; // whether src is a noncontiguous view
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, nr, v);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
return ggml_nbytes(t) * 2;
|
||||
}
|
||||
|
||||
test_repeat_back(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {8, 6, 4, 2},
|
||||
std::array<int, 4> nr = {2, 2, 2, 2},
|
||||
bool v = false)
|
||||
: type(type), ne(ne), nr(nr), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
|
||||
ggml_set_name(src, "src");
|
||||
|
||||
if (v) {
|
||||
GGML_ASSERT(ne[0] % 2 == 0);
|
||||
GGML_ASSERT(ne[1] % 2 == 0);
|
||||
GGML_ASSERT(ne[2] % 2 == 0);
|
||||
GGML_ASSERT(ne[3] % 2 == 0);
|
||||
GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
|
||||
GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
|
||||
GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
|
||||
GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
|
||||
|
||||
const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
|
||||
const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
|
||||
const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
|
||||
const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
|
||||
|
||||
src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
|
||||
}
|
||||
|
||||
ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(target, "target");
|
||||
|
||||
ggml_tensor * out = ggml_repeat_back(ctx, src, target);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_DUP
|
||||
struct test_dup : public test_case {
|
||||
const ggml_type type;
|
||||
@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
int64_t grad_nmax() override {
|
||||
return 20000;
|
||||
}
|
||||
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
|
||||
@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
|
||||
|
||||
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
if (!ggml_is_quantized(type_a)) {
|
||||
if (bs[1] == 1 && nr[1] == 1) {
|
||||
ggml_set_param(ctx, a);
|
||||
}
|
||||
ggml_set_param(ctx, b);
|
||||
}
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
|
||||
} else {
|
||||
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
if (!ggml_is_quantized(type_a)) {
|
||||
if (bs[1] == 1 && nr[1] == 1) {
|
||||
ggml_set_param(ctx, a);
|
||||
}
|
||||
ggml_set_param(ctx, b);
|
||||
}
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
}
|
||||
@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
|
||||
}
|
||||
|
||||
for (bool view : {false, true}) {
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
|
||||
@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
// test cases without permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
|
||||
|
||||
// test cases with permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
|
Loading…
Reference in New Issue
Block a user