From fa62da9b2dc03539a30f1306f59c4c6ffbe4f50a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= <johannesg@5d6.de>
Date: Wed, 5 Feb 2025 08:58:31 +0100
Subject: [PATCH] CUDA: support for mat. mul. with ne03 != ne13 (#11656)

---
 ggml/src/ggml-cuda/ggml-cuda.cu |  27 +++-----
 ggml/src/ggml-cuda/mmv.cu       | 114 ++++++++++++++++++++------------
 2 files changed, 81 insertions(+), 60 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 70a598099..4dbaefdba 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
@@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
             }
 
             // for split tensors the data begins at i0 == i0_offset_low
-            char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+            const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+            char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
             float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
             char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
             float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
@@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm_back(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                 return false;
             }
-            if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                return false;
-            }
 #ifdef GGML_USE_MUSA
             if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                 !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index 5a9ddd958..f89ed03b5 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -1,18 +1,21 @@
+#include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
     const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
+    const int64_t channel = blockIdx.y;
+    const int64_t sample  = blockIdx.z;
     const int     tid     = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   += (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=  sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=  sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y % nsamples_x == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
 
     int device;
     int warp_size;
@@ -118,40 +124,48 @@
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case 32: {
            mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 64: {
            mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 96: {
            mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 128: {
            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 160: {
            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 192: {
            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 224: {
            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         case 256: {
            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
         default: {
             GGML_ABORT("fatal error");
@@ -163,16 +177,19 @@ template <typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, half>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, float>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -181,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@@ -192,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const float * src1_d = (const float *) src1->data;
     float * dst_d = (float *) dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
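
Net effect of the patch: on CUDA, GGML_OP_MUL_MAT no longer requires src0 and src1 to have matching outer dimensions; src0 is broadcast across src1's channels (dim 2) and samples (dim 3) whenever ne12 % ne02 == 0 and ne13 % ne03 == 0, and the kernel grid maps blockIdx.y to the channel and blockIdx.z to the sample. The broadcast rule can be sketched on the CPU as below. This is a minimal illustrative sketch assuming contiguous F32 tensors; mul_mat_vec_ref and its layout conventions are assumptions for illustration, not code from the patch or from ggml.

#include <cstdint>
#include <vector>

// CPU reference for the broadcast rule: result vector (i13, i12) uses
// src0 matrix (i13/sample_ratio, i12/channel_ratio), mirroring the
// src0_dd_i offset and the (sample, channel) kernel indexing above.
static void mul_mat_vec_ref(
        const std::vector<float> & x,   // src0: ne03*ne02 row-major matrices, each ne01 rows x ne00 cols
        const std::vector<float> & y,   // src1: ne13*ne12 column vectors of length ne00
        std::vector<float>       & dst, // dst:  ne13*ne12 result vectors of length ne01
        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
        int64_t ne12, int64_t ne13) {
    const int64_t channel_ratio = ne12 / ne02; // requires ne12 % ne02 == 0
    const int64_t sample_ratio  = ne13 / ne03; // requires ne13 % ne03 == 0
    for (int64_t i13 = 0; i13 < ne13; ++i13) {     // "sample"  == blockIdx.z in the kernel
        for (int64_t i12 = 0; i12 < ne12; ++i12) { // "channel" == blockIdx.y in the kernel
            // integer division implements the broadcast: several src1
            // channels/samples can share a single src0 matrix
            const float * mat = x.data()   + ((i13/sample_ratio)*ne02 + i12/channel_ratio)*ne01*ne00;
            const float * vec = y.data()   + (i13*ne12 + i12)*ne00;
            float       * out = dst.data() + (i13*ne12 + i12)*ne01;
            for (int64_t row = 0; row < ne01; ++row) { // row == blockIdx.x in the kernel
                float sum = 0.0f;
                for (int64_t col = 0; col < ne00; ++col) {
                    sum += mat[row*ne00 + col] * vec[col];
                }
                out[row] = sum;
            }
        }
    }
}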