From fb43d5e8b57120e0dba713598116136bc3777e8c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 9 Jan 2025 12:16:02 -0500 Subject: [PATCH] ggml-cuda : cleanup TQ2_0 This also removes the custom TQ2_0 mmq dp4a, because re-using the one from Q8_0 avoids repeatedly unpacking the 2-bit values to 8-bit; they are instead unpacked only once per tile. --- ggml/src/ggml-cuda/mmq.cuh | 50 +++++----------------------------- ggml/src/ggml-cuda/vecdotq.cuh | 4 --- 2 files changed, 7 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 0cd671bc9..91c6d68ac 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -140,9 +140,6 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } -// tile_x_sizes{qs, dm, sc} - -// TODO: TQ2_0 to minimize shared mem #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} @@ -1814,7 +1811,6 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #endif // INT8_MMA_AVAILABLE } -// This is the first "simple" type with a block size of 256 template static __device__ __forceinline__ void load_tiles_tq2_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -1840,22 +1836,22 @@ template static __device__ __forceinlin const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE - #pragma unroll for (int l = 0; l < QR2_0; ++l) { // 0..7, 32..39 // 8..15, 40..47 // 16..23, 48..55 // 24..31, 56..63 - // FIXME: this might assume WARP_SIZE is >= 32 const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + const 
int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101); - } +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q; #else - x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0; + // NOTE: this might assume WARP_SIZE is >= 32 + x_qs[i*(2*WARP_SIZE + 1) + k] = q; #endif // INT8_MMA_AVAILABLE + } } // TODO: does this work with WARP_SIZE != 32? @@ -1872,7 +1868,6 @@ template static __device__ __forceinlin const int k = threadIdx.x % (QI2_0/2); #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d; #else x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d; @@ -1880,37 +1875,6 @@ template static __device__ __forceinlin } } -template -static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - - constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y); - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + txs.qs; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) { - const int k0 = k00 + k01; - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl( - &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01], - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); - // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - } - } -} - template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2535,7 
+2499,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 80019ab24..4105211f6 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -524,9 +524,6 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } -// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called -// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q - #define VDR_TQ2_0_Q8_1_MMVQ 2 #define VDR_TQ2_0_Q8_1_MMQ 8 @@ -547,7 +544,6 @@ template static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_im sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product } - // TODO: batch subtract by using d8 sum sumf += d8[i0] * sumi; }