diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 0ed92b3ff..9e627da8c 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
 
             float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
             float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+            float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
             ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
         }
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
 
             float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
             float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-            float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+            float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
             ggml_vec_mad_f32(ne0, d, s0, *s1);
         }
diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index 35a1c876c..2ccb4b472 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index c7b6be4e2..ce4b9cfb5 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
        return;
    }
 
     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
             }
         }
     }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float * dst_d = (float *) dst->data;
-            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
             GGML_ASSERT(false);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 7fd1fc853..e602419bc 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
         } break;
         case GGML_OP_REPEAT_BACK:
-            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+            return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu
index 73e3e2c47..c9b2b699c 100644
--- a/ggml/src/ggml-cuda/out-prod.cu
+++ b/ggml/src/ggml-cuda/out-prod.cu
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1 / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             CUBLAS_CHECK(
                 cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                 src1_d + i3 *s13 + i2 *s12, ldb,
-                        &beta,  dst_d + i3 *s3 + i2 *s2, ne0));
+                        &beta,  dst_d + i3 *s3 + i2 *s2, ldc));
         }
     }
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b1d0d4913..92c4294c5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -5339,7 +5339,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5431,21 +5431,25 @@ static void ggml_compute_backward(
             // src1.shape [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5514,7 +5518,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 381956a04..468016403 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
     }
 };
 
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {8, 6, 4, 2},
+            std::array<int64_t, 4> nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DUP
 struct test_dup : public test_case {
     const ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
+    int64_t grad_nmax() override {
+        return 20000;
+    }
+
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
 
             a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
 
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
         } else {
             a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
         }
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
     }
 
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
+
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             // test cases without permutation
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
 
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
 
             // test cases with permutation
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
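
Illustrative aside (not part of the patch): the ggml.c hunk for the MUL_MAT backward pass above sums the broadcast repetitions of the ggml_out_prod result back into src0's shape via ggml_view_4d + ggml_repeat_back. The helper below is a hypothetical, minimal C sketch of that same reduction using only public ggml API calls; it assumes broadcasting only along ne[2] (so tmp->ne[3] == 1), matching the asserts in the patch, and the function and variable names are made up for illustration.

// sketch.c -- hypothetical standalone illustration, not part of the patch
#include "ggml.h"
#include <stdio.h>

// Shapes follow the comments in ggml_compute_backward:
//   src0 = [n, m, q1, 1], src1 = [n, p, q1*nr2, 1], grad(dst) = [m, p, q1*nr2, 1]
static struct ggml_tensor * mul_mat_grad_src0_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * src0,
        struct ggml_tensor  * src1,
        struct ggml_tensor  * grad) {
    // gradient contribution with src1's outer dimensions: [n, m, q1*nr2, 1]
    struct ggml_tensor * tmp = ggml_out_prod(ctx, src1, grad);

    if (!ggml_are_same_shape(tmp, src0)) {
        // regroup dim 2 as (q1, nr2): consecutive nr2 slices of tmp were produced
        // from the same src0 slice, so stepping the new dim 2 skips nr2 slices
        // while the new dim 3 walks the repetitions one slice at a time
        const int64_t nr2 = tmp->ne[2] / src0->ne[2];
        const size_t  nb2 = tmp->nb[2] * nr2;
        const size_t  nb3 = tmp->nb[2];

        tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2,
                           tmp->nb[1], nb2, nb3, 0);
        // repeat_back sums the nr2 repetitions, yielding src0's shape [n, m, q1, 1]
        tmp = ggml_repeat_back(ctx, tmp, src0);
    }
    return tmp;
}

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n = 16, m = 8, p = 4, q1 = 3, nr2 = 2;
    struct ggml_tensor * src0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, m, q1,     1);
    struct ggml_tensor * src1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, p, q1*nr2, 1);
    struct ggml_tensor * grad = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, m, p, q1*nr2, 1);

    struct ggml_tensor * g0 = mul_mat_grad_src0_sketch(ctx, src0, src1, grad);
    printf("grad(src0) shape: %lld %lld %lld %lld\n",
           (long long) g0->ne[0], (long long) g0->ne[1], (long long) g0->ne[2], (long long) g0->ne[3]);

    ggml_free(ctx);
    return 0;
}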