diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 4f88a1ae1..328fc01ed 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2370,11 +2370,9 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds } -//#define IQ3S_MULTIPLIER 2469109 //#define IQ3S_MULTIPLIER 746226 //#define IQ3S_MULTIPLIER 717154 #define IQ3S_MULTIPLIER 677595 -//static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15}; template static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -2395,8 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); - aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201); aux32[0] = __vsub4(aux32[0] ^ signs0, signs0); @@ -5227,8 +5225,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( for (int l = 0; l < 4; ++l) { aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f; - aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); - aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101); + aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; + aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101; uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); const int grid_l = __vsub4(aux32[0] ^ signs0, signs0); diff --git a/ggml-quants.c b/ggml-quants.c index 9f9d299fe..5e98291bc 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10019,6 +10019,15 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +#ifdef __AVX2__ +static inline __m256i shift_left_epi16(__m256i a, __m256i count) { + const __m256i mask = _mm256_set1_epi32(0xffff0000); + const __m256i lo_half = _mm256_sllv_epi32(a, _mm256_andnot_si256(mask, count)); + const __m256i hi_half = _mm256_sllv_epi32(_mm256_and_si256(mask, a), _mm256_srli_epi32(count, 16)); + return _mm256_blend_epi16(lo_half, hi_half, 0xaa); +} +#endif + void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -10109,6 +10118,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + const __m256i idx_mask = _mm256_set1_epi16(256); + const __m256i idx_shift = _mm256_set_epi16(8, 7, 6, 5, 4, 3, 2, 1, 8, 7, 6, 5, 4, 3, 2, 1); + const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER); + const __m256i m1 = _mm256_set1_epi32(0x01010101); + const __m256i m7 = _mm256_set1_epi32(0x07070707); + const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f); + __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; @@ -10121,24 +10137,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]); - qs += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)], - iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)], - iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)], - iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)], - iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)], - iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]); - qs += 8; + const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8); + const __m256i idx_h_16 = _mm256_set_m128i(_mm_set1_epi16(qh[ib32+1]), _mm_set1_epi16(qh[ib32+0])); + const __m256i idx_16 = _mm256_or_si256(idx_l_16, _mm256_and_si256(shift_left_epi16(idx_h_16, idx_shift), idx_mask)); + const __m256i idx_32_l = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_16)); + const __m256i idx_32_h = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_16, 1)); + const __m256i idx_l = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1); + const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1); + const __m256i idx_h = _mm256_add_epi32(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1); + const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1); __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); @@ -10166,7 +10174,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v } - *s = 0.25f * hsum_float_8(accumf); + *s = hsum_float_8(accumf); #else