CUDA: stream-k decomposition for MMQ (#8018)

* CUDA: stream-k decomposition for MMQ

* fix undefined memory reads for small matrices
This commit is contained in:
Johannes Gäßler 2024-06-20 14:39:21 +02:00 committed by GitHub
parent 2075a66a96
commit d50f8897a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 292 additions and 113 deletions

View File

@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
} }
const int cc = ggml_cuda_info().devices[id].cc; const int cc = ggml_cuda_info().devices[id].cc;
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc))); row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
} }
return row_rounding; return row_rounding;
} }

View File

@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
} }
// Round rows to this value for --split-mode row: // Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc, const int mmq_x) { static int get_mmq_y_host(const int cc) {
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64; return cc >= CC_VOLTA ? 128 : 64;
} }
////////////////////// //////////////////////

View File

@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream); mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream); mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream); mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream); mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream); mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break; break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream); mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break; break;
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream); mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
break; break;
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream); mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
break; break;
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream); mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
break; break;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream); mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);

View File

@ -8,6 +8,7 @@
#include <cstdint> #include <cstdint>
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)( typedef void (*load_tiles_mmq_t)(
const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm, const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
@ -15,7 +16,7 @@ typedef void (*load_tiles_mmq_t)(
typedef void (*vec_dot_mmq_t)( typedef void (*vec_dot_mmq_t)(
const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0); const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
struct block_q8_1_mmq { struct block_q8_1_mmq {
half2 ds[4]; half2 ds[4];
@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {
// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static constexpr __device__ int get_mmq_y_device(int mmq_x) { return 128;
return mmq_x >= 32 ? 128 : 64;
}
#else #else
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= CC_VOLTA
static constexpr __device__ int get_mmq_y_device(int mmq_x) { return 128;
return mmq_x >= 32 ? 128 : 64;
}
#else #else
static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
return 64; return 64;
}
#endif // __CUDA_ARCH__ >= CC_VOLTA #endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
} }
template<int mmq_x, int mmq_y, int nwarps, bool need_check> template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { static __device__ __forceinline__ void mmq_write_back_dp4a(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
#pragma unroll #pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; const int j = j0 + threadIdx.y;
if (j >= ne1) { if (j > j_max) {
return; return;
} }
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; const int i = i0 + threadIdx.x;
if (need_check && i >= ne0) { if (need_check && i > i_max) {
continue; continue;
} }
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
} }
} }
} }
template<int mmq_x, int mmq_y, int nwarps, bool need_check> template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) { static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
typedef mma_int_C_I16J8 mma_C; typedef mma_int_C_I16J8 mma_C;
const int i0 = threadIdx.y*mma_C::I; const int i0 = threadIdx.y*mma_C::I;
@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) { for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
#pragma unroll #pragma unroll
for (int l = 0; l < mma_C::ne; ++l) { for (int l = 0; l < mma_C::ne; ++l) {
const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l); const int j = j0 + mma_C::get_j(l);
if (j >= ne1) { if (j > j_max) {
continue; continue;
} }
const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l); const int i = i0 + mma_C::get_i(l);
if (need_check && i >= ne0) { if (need_check && i > i_max) {
continue; continue;
} }
dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l]; dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
} }
} }
} }
@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
return false; return false;
} }
template <ggml_type type, int mmq_x, int nwarps, bool need_check> template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) static __device__ void mul_mat_q_process_tile(
#if defined(RDNA3) || defined(RDNA2) const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
__launch_bounds__(WARP_SIZE*nwarps, 2) const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
#endif // defined(RDNA3) || defined(RDNA2) const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
#else
#if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
#else
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device()) {
NO_DEVICE_CODE;
return;
}
constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qr = ggml_cuda_type_traits<type>::qr; constexpr int qr = ggml_cuda_type_traits<type>::qr;
constexpr int qi = ggml_cuda_type_traits<type>::qi; constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device(mmq_x); constexpr int mmq_y = get_mmq_y_device();
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr; constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles; constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
int * tile_x_sc = (int *) (tile_x_dm + txs.dm); int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)] int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
const int blocks_per_row_x = ne00 / qk; constexpr int blocks_per_warp = WARP_SIZE / qi;
const int blocks_per_warp = WARP_SIZE / qi;
const int & ne1 = ne11;
const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { const int tile_x_max_i = ne01 - it*mmq_y - 1;
const int tile_y_max_j = ne11 - jt*mmq_x - 1;
load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01); const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
#pragma unroll #pragma unroll
for (int kr = 0; kr < qr; ++kr) { for (int kr = 0; kr < qr; ++kr) {
@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
} }
} }
write_back(sum, dst, ne0, ne1); if (fixup) {
write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
} else {
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
}
}
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
#else
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device()) {
NO_DEVICE_CODE;
return;
}
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int mmq_y = get_mmq_y_device();
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
{
constexpr bool fixup = false;
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
blockIdx.x, blockIdx.y, 0, ne00/qk);
return;
}
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
const int64_t blocks_per_ne00 = ne00 / qk;
constexpr int blocks_per_warp = WARP_SIZE / qi;
const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
// kb0 == k index when doing the matrix multiplication for an output tile.
int kb0_start = kbc % blocks_per_ne00;
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);
kbc += blocks_per_ne00;
kbc -= kbc % blocks_per_ne00;
kb0_start = 0;
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
}
if (kbc >= kbc_stop) {
return;
}
const int jt = kbc / (blocks_per_ne00*nty);
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);
}
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int blocks_per_warp = WARP_SIZE / qi;
const int64_t blocks_per_ne00 = ne00 / qk;
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
const int ntx = (ne11 + mmq_x - 1) / mmq_x;
const int nty = (ne01 + mmq_y - 1) / mmq_y;
bool any_fixup = false;
const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
continue;
}
const int jt = kbc_stop / (blocks_per_ne00*nty);
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
if (it != blockIdx.x || jt != blockIdx.y) {
continue;
}
any_fixup = true;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
}
}
}
if (!any_fixup) {
return;
}
dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
const int i_max = ne01 - blockIdx.x*mmq_y - 1;
const int j_max = ne11 - blockIdx.y*mmq_x - 1;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j > j_max) {
return;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (need_check && i > i_max) {
continue;
}
dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
} }
struct mmq_args { struct mmq_args {
@ -1987,124 +2139,151 @@ struct mmq_args {
int64_t ne0; int64_t ne0;
}; };
constexpr int mmq_get_nwarps(int mmq_x) {
return mmq_x >= 32 ? 8 : 4;
}
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) { static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
const int nwarps = mmq_get_nwarps(mmq_x);
const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int)); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
} }
template <ggml_type type, int mmq_x, int nwarps> template <ggml_type type, int mmq_x>
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device(); const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc; const int cc = ggml_cuda_info().devices[id].cc;
const int mmq_y = get_mmq_y_host(cc, mmq_x); const int nsm = ggml_cuda_info().devices[id].nsm;
const int mmq_y = get_mmq_y_host(cc);
const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y; const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y); const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shmem_limit_raised[id]) { if (!shmem_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
shmem_limit_raised[id] = true; shmem_limit_raised[id] = true;
} }
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
const dim3 block_nums_xy_tiling(nty, ntx, 1);
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
if (!use_stream_k) {
if (args.ne01 % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
}
return;
}
const dim3 block_nums_mmq(nsm, 1, 1);
ggml_cuda_pool & pool = ctx.pool();
ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
if (args.ne01 % mmq_y == 0) { if (args.ne01 % mmq_y == 0) {
const bool need_check = false; constexpr bool need_check = false;
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
} else { } else {
const bool need_check = true; constexpr bool need_check = true;
mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
} }
} }
template <ggml_type type> template <ggml_type type>
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device(); const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm; const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc; const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo; const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc, mmq_x_max); const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
int mmq_x_best = 0; int mmq_x_best = 0;
int nwaves_best = INT_MAX; int nparts_best = INT_MAX;
for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) { for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x; const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm; const int nwaves_xy_tiling = ntiles_x*block_num_y;
if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) { const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
mmq_x_best = mmq_x; mmq_x_best = mmq_x;
nwaves_best = nwaves; nparts_best = nparts;
} }
} }
switch (mmq_x_best) { switch (mmq_x_best) {
case 8: case 8:
launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream); launch_mul_mat_q<type, 8>(ctx, args, stream);
break; break;
case 16: case 16:
launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream); launch_mul_mat_q<type, 16>(ctx, args, stream);
break; break;
case 24: case 24:
launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream); launch_mul_mat_q<type, 24>(ctx, args, stream);
break; break;
case 32: case 32:
launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream); launch_mul_mat_q<type, 32>(ctx, args, stream);
break; break;
case 40: case 40:
launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream); launch_mul_mat_q<type, 40>(ctx, args, stream);
break; break;
case 48: case 48:
launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream); launch_mul_mat_q<type, 48>(ctx, args, stream);
break; break;
case 56: case 56:
launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream); launch_mul_mat_q<type, 56>(ctx, args, stream);
break; break;
case 64: case 64:
launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream); launch_mul_mat_q<type, 64>(ctx, args, stream);
break; break;
case 72: case 72:
launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream); launch_mul_mat_q<type, 72>(ctx, args, stream);
break; break;
case 80: case 80:
launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream); launch_mul_mat_q<type, 80>(ctx, args, stream);
break; break;
case 88: case 88:
launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream); launch_mul_mat_q<type, 88>(ctx, args, stream);
break; break;
case 96: case 96:
launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream); launch_mul_mat_q<type, 96>(ctx, args, stream);
break; break;
case 104: case 104:
launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream); launch_mul_mat_q<type, 104>(ctx, args, stream);
break; break;
case 112: case 112:
launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream); launch_mul_mat_q<type, 112>(ctx, args, stream);
break; break;
case 120: case 120:
launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream); launch_mul_mat_q<type, 120>(ctx, args, stream);
break; break;
case 128: case 128:
launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream); launch_mul_mat_q<type, 128>(ctx, args, stream);
break; break;
default: default:
fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
} }
#define DECL_MMQ_CASE(type) \ #define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \ template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);