diff --git a/ggml-quants.c b/ggml-quants.c index 3f2ed0a5f..a24b4b244 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -7605,36 +7605,53 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest #elif defined(__AVX2__) + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m127 = _mm_set1_epi16(127); + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m128i aux_gindex, aux_sindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + const uint16_t * sindex = (const uint16_t *)&aux_sindex; + __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[q2[3] & 511], iq2xs_grid[q2[2] & 511], iq2xs_grid[q2[1] & 511], iq2xs_grid[q2[0] & 511]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[q2[7] & 511], iq2xs_grid[q2[6] & 511], iq2xs_grid[q2[5] & 511], iq2xs_grid[q2[4] & 511]); - const __m256i s2_1 = _mm256_set_epi64x(signs64[q2[3] >> 9], signs64[q2[2] >> 9], signs64[q2[1] >> 9], signs64[q2[0] >> 9]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[q2[7] >> 9], signs64[q2[6] >> 9], signs64[q2[5] >> 9], signs64[q2[4] >> 9]); + const __m128i q2_data = _mm_loadu_si128((const __m128i*)q2); q2 += 8; + aux_gindex = _mm_and_si128(q2_data, m511); + aux_sindex = _mm_and_si128(_mm_srli_epi16(q2_data, 9), m127); + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]], iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]], iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[sindex[3]], signs64[sindex[2]], signs64[sindex[1]], signs64[sindex[0]]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[sindex[7]], signs64[sindex[6]], signs64[sindex[5]], signs64[sindex[4]]); const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = 2*(sc[0] & 0xf) + 1, ls2 = 2*(sc[0] >> 4) + 1; - const uint16_t ls3 = 2*(sc[1] & 0xf) + 1, ls4 = 2*(sc[1] >> 4) + 1; - const __m256i p1 = _mm256_madd_epi16(dot1, MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1))); - const __m256i p2 = _mm256_madd_epi16(dot2, MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3))); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - q2 += 8; - sc += 2; + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); } accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);