CUDA: stream-k decomposition for MMQ (#8018)

* CUDA: stream-k decomposition for MMQ
* fix undefined memory reads for small matrices

parent 2075a66a96
commit d50f8897a7
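The key idea of the change: instead of launching one CUDA block per output tile (xy-tiling), the kernel now launches roughly one block per SM and walks a continuous index space of ntx*nty output tiles times ne00/qk k-blocks; a block whose slice ends in the middle of an output tile parks its partial sums in a fixup buffer that a second kernel (mul_mat_q_stream_k_fixup) folds into dst. The standalone C++ sketch below is not part of the commit and uses made-up sizes; it only reproduces the partitioning arithmetic from the diff to show which blocks need the fixup pass.

// Standalone sketch of the stream-k split used by the new mul_mat_q kernel.
// All sizes below are made-up examples; pad() mirrors GGML_PAD from ggml.h.
#include <cstdint>
#include <cstdio>

static int64_t pad(int64_t x, int64_t n) { return (x + n - 1) / n * n; }

int main() {
    const int64_t blocks_per_ne00 = 64;  // k-blocks per row (ne00/qk)
    const int     ntx = 3, nty = 5;      // output tiles along ne11 and ne01
    const int     n_sm = 16;             // CUDA blocks launched (~ number of SMs)
    const int     blocks_per_warp = 4;   // granularity of the split (WARP_SIZE/qi)

    const int64_t total = blocks_per_ne00*ntx*nty;
    for (int b = 0; b < n_sm; ++b) {
        const int64_t kbc      = pad((int64_t) b     *total/n_sm, blocks_per_warp);
        const int64_t kbc_stop = pad((int64_t)(b + 1)*total/n_sm, blocks_per_warp);
        // A block whose slice ends mid-tile writes that tile's partial sums to the
        // fixup buffer; mul_mat_q_stream_k_fixup later adds them into dst.
        const bool needs_fixup = kbc != kbc_stop && kbc_stop % blocks_per_ne00 != 0;
        printf("block %2d: kbc=[%4lld, %4lld)  fixup=%d\n",
               b, (long long) kbc, (long long) kbc_stop, needs_fixup);
    }
    return 0;
}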
@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
        }

        const int cc = ggml_cuda_info().devices[id].cc;
-       row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+       row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
    }
    return row_rounding;
 }
@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
 }

 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+static int get_mmq_y_host(const int cc) {
+    return cc >= CC_VOLTA ? 128 : 64;
 }

 //////////////////////
@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(

     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
         default:
             GGML_ASSERT(false);
@@ -8,6 +8,7 @@
 #include <cstdint>

 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8

 typedef void (*load_tiles_mmq_t)(
     const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
@@ -15,7 +16,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
     const int * __restrict__ y, float * __restrict__ sum, const int & k0);
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

 struct block_q8_1_mmq {
     half2 ds[4];
@@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {

 // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row

+static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
     return 64;
-}
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}

 #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
@@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }

 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+        const int j = j0 + threadIdx.y;

-        if (j >= ne1) {
+        if (j > j_max) {
             return;
         }

 #pragma unroll
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+            const int i = i0 + threadIdx.x;

-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }

-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }

 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

     typedef mma_int_C_I16J8 mma_C;

     const int i0 = threadIdx.y*mma_C::I;
@@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+            const int j = j0 + mma_C::get_j(l);

-            if (j >= ne1) {
+            if (j > j_max) {
                 continue;
             }

-            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+            const int i = i0 + mma_C::get_i(l);

-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }

-            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+            dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
         }
     }
 }
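The new (stride, i_max, j_max) signature lets the same write-back helpers target either the final dst matrix (stride ne0) or a per-block fixup tile (stride mmq_y), which is what the stream-k path further down relies on. Below is a plain C++ illustration of the two call shapes; it is a sketch with assumed sizes and serial loops, not the per-thread register layout of the kernel.

// Sketch (assumed sizes, serial loops) of the two ways the new write-back
// signature is used: straight into dst, or into a fixup tile with stride mmq_y.
#include <vector>

static void write_back_tile(const float * sum, float * dst, int stride,
                            int i_max, int j_max, int mmq_y, int mmq_x) {
    for (int j = 0; j < mmq_x && j <= j_max; ++j) {
        for (int i = 0; i < mmq_y && i <= i_max; ++i) {
            dst[j*stride + i] = sum[j*mmq_y + i];   // clipped by i_max/j_max at the matrix edge
        }
    }
}

int main() {
    const int ne0 = 300, ne1 = 50;            // output matrix sizes (assumed)
    const int mmq_y = 128, mmq_x = 64;        // tile size
    const int it = 2, jt = 0;                 // tile indices: this tile covers rows [256, 300)
    std::vector<float> sum(mmq_x*mmq_y, 1.0f);
    std::vector<float> dst((size_t) ne0*ne1), tmp_fixup(mmq_x*mmq_y);

    // Full result written into dst: stride is ne0, limits clip to the matrix edge.
    write_back_tile(sum.data(), dst.data() + jt*mmq_x*ne0 + it*mmq_y,
                    ne0, ne0 - it*mmq_y - 1, ne1 - jt*mmq_x - 1, mmq_y, mmq_x);
    // Partial result parked in the fixup buffer: stride is mmq_y, no clipping needed.
    write_back_tile(sum.data(), tmp_fixup.data(), mmq_y, mmq_y, mmq_x, mmq_y, mmq_x);
    return 0;
}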
@@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
     return false;
 }

-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#else
-#if __CUDA_ARCH__ >= CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
-#else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
-    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {

-    // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int qr = ggml_cuda_type_traits<type>::qr;
     constexpr int qi = ggml_cuda_type_traits<type>::qi;
-    constexpr int mmq_y = get_mmq_y_device(mmq_x);
+    constexpr int mmq_y = get_mmq_y_device();
     constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;

@@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
     int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
     int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]

-    const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
-
-    const int & ne1 = ne11;
-
-    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
-
-    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+    constexpr int blocks_per_warp = WARP_SIZE / qi;

     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

-    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
+
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {

-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);

 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
         }
     }

-    write_back(sum, dst, ne0, ne1);
-}
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi = ggml_cuda_type_traits<type>::qi;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+    const int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt = kbc / (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+         it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int mmq_y = get_mmq_y_device();
+    constexpr int qk = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi = ggml_cuda_type_traits<type>::qi;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+    const int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
+    const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt = kbc_stop / (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}

 struct mmq_args {
@@ -1987,124 +2139,151 @@ struct mmq_args {
     int64_t ne0;
 };

-constexpr int mmq_get_nwarps(int mmq_x) {
-    return mmq_x >= 32 ? 8 : 4;
-}
-
 static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
     const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int nwarps = mmq_get_nwarps(mmq_x);

     const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }

-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);

-    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);

     const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);

 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

-    if (args.ne01 % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
-    } else {
-        const bool need_check = true;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
+    if (args.ne01 % mmq_y == 0) {
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    } else {
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     }
 }

 template <ggml_type type>
-void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int nsm = ggml_cuda_info().devices[id].nsm;
     const int cc = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;

     const int mmq_x_max = get_mmq_x_max_host(cc);
-    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;

     int mmq_x_best = 0;
-    int nwaves_best = INT_MAX;
+    int nparts_best = INT_MAX;

-    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
-        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
-        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;

-        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+
+        if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
             mmq_x_best = mmq_x;
-            nwaves_best = nwaves;
+            nparts_best = nparts;
         }
     }

     switch (mmq_x_best) {
         case 8:
-            launch_mul_mat_q<type, 8, mmq_get_nwarps( 8)>(args, stream);
+            launch_mul_mat_q<type, 8>(ctx, args, stream);
             break;
         case 16:
-            launch_mul_mat_q<type, 16, mmq_get_nwarps( 16)>(args, stream);
+            launch_mul_mat_q<type, 16>(ctx, args, stream);
             break;
         case 24:
-            launch_mul_mat_q<type, 24, mmq_get_nwarps( 24)>(args, stream);
+            launch_mul_mat_q<type, 24>(ctx, args, stream);
             break;
         case 32:
-            launch_mul_mat_q<type, 32, mmq_get_nwarps( 32)>(args, stream);
+            launch_mul_mat_q<type, 32>(ctx, args, stream);
             break;
         case 40:
-            launch_mul_mat_q<type, 40, mmq_get_nwarps( 40)>(args, stream);
+            launch_mul_mat_q<type, 40>(ctx, args, stream);
             break;
         case 48:
-            launch_mul_mat_q<type, 48, mmq_get_nwarps( 48)>(args, stream);
+            launch_mul_mat_q<type, 48>(ctx, args, stream);
             break;
         case 56:
-            launch_mul_mat_q<type, 56, mmq_get_nwarps( 56)>(args, stream);
+            launch_mul_mat_q<type, 56>(ctx, args, stream);
             break;
         case 64:
-            launch_mul_mat_q<type, 64, mmq_get_nwarps( 64)>(args, stream);
+            launch_mul_mat_q<type, 64>(ctx, args, stream);
             break;
         case 72:
-            launch_mul_mat_q<type, 72, mmq_get_nwarps( 72)>(args, stream);
+            launch_mul_mat_q<type, 72>(ctx, args, stream);
             break;
         case 80:
-            launch_mul_mat_q<type, 80, mmq_get_nwarps( 80)>(args, stream);
+            launch_mul_mat_q<type, 80>(ctx, args, stream);
             break;
         case 88:
-            launch_mul_mat_q<type, 88, mmq_get_nwarps( 88)>(args, stream);
+            launch_mul_mat_q<type, 88>(ctx, args, stream);
             break;
         case 96:
-            launch_mul_mat_q<type, 96, mmq_get_nwarps( 96)>(args, stream);
+            launch_mul_mat_q<type, 96>(ctx, args, stream);
             break;
         case 104:
-            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
             break;
         case 112:
-            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
             break;
         case 120:
-            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
             break;
         case 128:
-            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
             break;
         default:
             fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
 }

 #define DECL_MMQ_CASE(type) \
-    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
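Finally, the changed heuristic in mul_mat_q_case: with stream-k, only the tile count along ne11 matters for picking mmq_x, since the k dimension is already spread across the SMs. The small sketch below replays that selection loop with assumed sizes; the real code additionally requires the shared-memory check against smpbo, which is omitted here.

// Sketch of the updated mmq_x selection (assumed device/matrix sizes; the real
// code additionally requires mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo).
#include <climits>
#include <cstdio>

int main() {
    const int ne01 = 4096, ne11 = 50;   // example matrix sizes (assumed)
    const int mmq_x_max = 128, mmq_y = 128;
    const bool use_stream_k = true;     // cc >= CC_VOLTA && cc < CC_OFFSET_AMD

    const int block_num_y = (ne01 + mmq_y - 1) / mmq_y;

    int mmq_x_best  = 0;
    int nparts_best = INT_MAX;
    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
        const int ntiles_x = (ne11 + mmq_x - 1) / mmq_x;
        const int nparts   = use_stream_k ? ntiles_x : ntiles_x*block_num_y;
        if (nparts < nparts_best) {
            mmq_x_best  = mmq_x;
            nparts_best = nparts;
        }
    }
    printf("mmq_x_best=%d nparts_best=%d\n", mmq_x_best, nparts_best);
    return 0;
}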