From b587482287014e5440da047376858d6af88a6340 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 4 Mar 2024 19:43:22 +0200 Subject: [PATCH] iq3_s_mult_shuffle: mult + shuffle based codebook --- ggml-cuda.cu | 53 +++++++++++++++++++++--------------- ggml-quants.c | 74 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 81 insertions(+), 46 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ff721ea43..6f8c4a3ac 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2375,8 +2375,12 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds // Better (lower PPL), but requires more bit twidling, so slower #define IQ3S_MULTIPLIER 190842953LL #else -#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 72968561ULL +//#define IQ3S_MULTIPLIER 540201 +//#define IQ3S_MULTIPLIER 1378231 +//#define IQ3S_MULTIPLIER 898886 //#define IQ3S_MULTIPLIER 842866 +static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; #endif template @@ -2400,32 +2404,36 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; #else - aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); #endif -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#ifdef IQ3S_SLOW_MULT - aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; - aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; -#endif - uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); - aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); - aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); - for (int j = 0; j < 8; ++j) { - y[j] = d * grid[j]; - } -#else +//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +//#ifdef IQ3S_SLOW_MULT +// aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; +// aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101; +//#endif +// uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); +// uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); +// aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); +// aux32[1] = __vsub4(aux32[1] ^ signs1, signs1); +// for (int j = 0; j < 8; ++j) { +// y[j] = d * grid[j]; +// } +//#else #ifdef IQ3S_SLOW_MULT for (int j = 0; j < 8; ++j) { y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } #else + //static const uint8_t k_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; for (int j = 0; j < 8; ++j) { - y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + //y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } #endif -#endif +//#endif #else assert(false); #endif @@ -5225,7 +5233,6 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } -// TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -5233,6 +5240,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const block_iq3_s * bq2 = (const block_iq3_s *) vbq; uint32_t aux32[2]; + uint8_t * aux8 = (uint8_t *)aux32; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -5249,8 +5257,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101; #else - aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]]; #endif uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); diff --git a/ggml-quants.c b/ggml-quants.c index 70af6e623..0535f5c25 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -4058,12 +4058,18 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // Best PPL #define IQ3S_MULTIPLIER 190842953 #else -#define IQ3S_MULTIPLIER 898886 +#define IQ3S_MULTIPLIER 72968561ULL +//#define IQ3S_MULTIPLIER 540201 +//#define IQ3S_MULTIPLIER 1378231 +//#define IQ3S_MULTIPLIER 898886 //#define IQ3S_MULTIPLIER 842866 #endif #define IQ3S_BITS 3 +static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; +//static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15}; + void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -4099,10 +4105,15 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else - aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //for (int j = 0; j < 8; ++j) { + // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { - y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #endif y += 8; @@ -4118,12 +4129,17 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } #else - aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; -#endif + //aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101; + //for (int j = 0; j < 8; ++j) { + // y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + //} + aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); + aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f); for (int j = 0; j < 8; ++j) { - y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); + y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } +#endif y += 8; } qh += 2; @@ -10073,12 +10089,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m128i shuffle128 = _mm_loadu_si128((const __m128i *)iq3s_values); + const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128); - const __m256i idx_mask = _mm256_set1_epi32(256); - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); - const __m256i m1 = _mm256_set1_epi8(1); + //const __m256i m1 = _mm256_set1_epi8(1); const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + const __m256i m100 = _mm256_set1_epi32(0x0100); #ifdef IQ3S_SLOW_MULT const __m256i m7 = _mm256_set1_epi32(0x07070707); const __m256i m0 = _mm256_setzero_si256(); @@ -10096,12 +10113,19 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8); - const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask); - const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask); - const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16))); - const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1))); + + const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8; + uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL; + const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)), + _mm256_setzero_si256()); + const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32( + _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)), + _mm256_setzero_si256()); + const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1); + const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2); #ifdef IQ3S_SLOW_MULT const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0); @@ -10109,12 +10133,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0); const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); #else - //const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); - //const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); - //const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); - //const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); - const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); - const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15)); + const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15)); #endif __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); @@ -11364,10 +11384,14 @@ static void iq3xs_init_grid512(void) { #ifdef IQ3S_SLOW_MULT aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f; #else - aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; + //aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101; + aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f); #endif + //for (int i = 0; i < 4; ++i) { + // pos[i] = 2*((q4[i]-1)/2) + 1; + //} for (int i = 0; i < 4; ++i) { - pos[i] = 2*((q4[i]-1)/2) + 1; + pos[i] = iq3s_values[q4[i]]; } }