ggml-cuda : cleanup TQ2_0

This also removes custom TQ2_0 mmq dp4a,
because re-using the one from Q8_0 allows avoiding
to repeatedly unpack the 2-bit values to 8-bit
and instead only do it once per tile.
This commit is contained in:
Francis Couture-Harpin 2025-01-09 12:16:02 -05:00
parent 970b5ab7ca
commit fb43d5e8b5
2 changed files with 7 additions and 47 deletions

View File

@ -140,9 +140,6 @@ static constexpr __device__ int get_mmq_y_device() {
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
// tile_x_sizes{qs, dm, sc}
// TODO: TQ2_0 to minimize shared mem
#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
@ -1814,7 +1811,6 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#endif // INT8_MMA_AVAILABLE
}
// This is the first "simple" type with a block size of 256
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@ -1840,22 +1836,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
const int qs0 = get_int_b2(bxi->qs, kqsx);
#ifdef INT8_MMA_AVAILABLE
#pragma unroll
for (int l = 0; l < QR2_0; ++l) {
// 0..7, 32..39
// 8..15, 40..47
// 16..23, 48..55
// 24..31, 56..63
// FIXME: this might assume WARP_SIZE is >= 32
const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
const int q = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
}
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = q;
#else
x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0;
// NOTE: this might assume WARP_SIZE is >= 32
x_qs[i*(2*WARP_SIZE + 1) + k] = q;
#endif // INT8_MMA_AVAILABLE
}
}
// TODO: does this work with WARP_SIZE != 32?
@ -1872,7 +1868,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const int k = threadIdx.x % (QI2_0/2);
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
#else
x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
@ -1880,37 +1875,6 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + txs.qs;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
&x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
// x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
@ -2535,7 +2499,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>

View File

@ -524,9 +524,6 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
#define VDR_TQ2_0_Q8_1_MMVQ 2
#define VDR_TQ2_0_Q8_1_MMQ 8
@ -547,7 +544,6 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_im
sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
}
// TODO: batch subtract by using d8 sum
sumf += d8[i0] * sumi;
}