CUDA: optimize and refactor MMQ (#8416)

* CUDA: optimize and refactor MMQ

* explicit q8_1 memory layouts, add documentation
This commit is contained in:
Johannes Gäßler 2024-07-11 16:47:47 +02:00 committed by GitHub
parent a977c11544
commit 808aba3916
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 867 additions and 687 deletions

View File

@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
}
};
struct mma_int_B_J8K4 {

File diff suppressed because it is too large Load Diff

View File

@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
template <bool need_sum>
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
if (ix0 >= kx0_padded) {
return;
}
const float4 * x4 = (const float4 *) x;
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
float amax = fabsf(xi);
// Load 4 floats per thread and calculate max. abs. value between them:
const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
float amax = fabsf(xi.x);
amax = fmaxf(amax, fabsf(xi.y));
amax = fmaxf(amax, fabsf(xi.z));
amax = fmaxf(amax, fabsf(xi.w));
amax = warp_reduce_max(amax);
float sum;
if (need_sum) {
sum = warp_reduce_sum(xi);
// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
}
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
float sum;
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
sum = xi.x + xi.y + xi.z + xi.w;
y[ib].qs[iqs] = q;
// Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
}
}
const float d_inv = 127.0f / amax;
char4 q;
q.x = roundf(xi.x*d_inv);
q.y = roundf(xi.y*d_inv);
q.z = roundf(xi.z*d_inv);
q.w = roundf(xi.w*d_inv);
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
char4 * yqs4 = (char4 *) y[ib].qs;
yqs4[iqs/4] = q;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
if (iqs % 16 != 0 || iqs >= 96) {
return;
}
y[ib].d2s6[2 + iqs/16] = sum;
if (iqs % 64 != 0) {
return;
}
const float d = 1.0f / d_inv;
y[ib].d2s6[iqs/64] = d;
if (iqs % QK8_1 != 0) {
return;
}
if (need_sum) {
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
if (iqs % 32 != 0) {
return;
}
const float d = 1.0f / d_inv;
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
y[ib].ds4[iqs/32] = make_half2(d, sum);
} else {
((float *) y[ib].ds)[iqs/QK8_1] = d;
y[ib].d4[iqs/32] = d;
}
}
@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
const dim3 num_blocks(block_num_x, kx1, channels);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
if (mmq_need_sum(type_x)) {
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
} else {
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
switch (mmq_get_q8_1_ds_layout(type_x)) {
case MMQ_Q8_1_DS_LAYOUT_D4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
case MMQ_Q8_1_DS_LAYOUT_DS4:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
case MMQ_Q8_1_DS_LAYOUT_D2S6:
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
break;
default:
GGML_ASSERT(false);
break;
}
}

View File

@ -5,7 +5,11 @@
#include <cstdint>
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
typedef void (*quantize_cuda_t)(
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,

View File

@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2
#define VDR_Q2_K_Q8_1_MMQ 4
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
return dm2f.x*sumf_d - dm2f.y*sumf_m;
}
// contiguous u/y values
// contiguous v/x + u/y values
template <int ns8>
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
float sumf_d = 0.0f;
float sumf_m = 0.0f;
float sumf = 0.0f;
float sumf_d8 = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
int sumi_d = 0;
int sumi_m = 0;
for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
int sumi_d0 = 0;
const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
int sumi_d1 = 0;
const int vi0 = v[i0/(QI8_1/2)];
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
}
sumf_d8 += dm2f0.x * sumi_d0;
sumf_d += dm2f.x * sumi_d;
sumf_m += dm2f.y * sumi_m;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
}
sumf_d8 += dm2f1.x * sumi_d1;
if (i0/QI8_1 < ns8) {
const float2 s8f = __half22float2(s8[i0/QI8_1]);
sumf -= dm2f0.y*s8f.x;
sumf -= dm2f1.y*s8f.y;
} else {
int sumi_m0 = 0;
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
}
sumf_d8 -= dm2f0.y * sumi_m0;
int sumi_m1 = 0;
#pragma unroll
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
}
sumf_d8 -= dm2f1.y * sumi_m1;
}
}
return d8*(sumf_d - sumf_m);
return sumf + d8*sumf_d8;
}
#define VDR_Q3_K_Q8_1_MMVQ 1
@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
return d3 * sumf;
}
// contiguous u/y values
// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
const float & d3, const float & d8) {
@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
}
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
}
// contiguous u/y values
// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
return dm5f.x*sumf_d - dm5f.y*sumf_m;
}
// contiguous u/y values
// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
return d*sumf;
}
// contiguous u/y values
// contiguous v/x + u/y values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
const float & d6, const float * __restrict__ d8) {
float sumf_d = 0.0f;
const int sc_packed = get_int_b4(sc, 0);
const int8_t * sc_reg = (const int8_t *) &sc_packed;
#pragma unroll
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
}
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
}
return d6 * sumf_d;