Q6_K AVX improvements (#10118)

* q6_k instruction reordering attempt

* better subtract method

* should be theoretically faster

small improvement with shuffle lut, likely because all loads are already done at that stage

* optimize bit fiddling

* handle -32 offset separately. bsums exists for a reason!

* use shift

* Update ggml-quants.c

* have to update ci macos version to 13 as 12 doesnt work now. 13 is still x86
This commit is contained in:
Eve 2024-11-04 22:06:31 +00:00 committed by GitHub
parent d5a409e57f
commit 3407364776
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 51 deletions

View File

@ -92,7 +92,7 @@ jobs:
name: llama-bin-macos-arm64.zip name: llama-bin-macos-arm64.zip
macOS-latest-cmake-x64: macOS-latest-cmake-x64:
runs-on: macos-12 runs-on: macos-13
steps: steps:
- name: Clone - name: Clone

View File

@ -9104,10 +9104,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
#elif defined __AVX__ #elif defined __AVX__
const __m128i m4 = _mm_set1_epi8(0xF);
const __m128i m3 = _mm_set1_epi8(3); const __m128i m3 = _mm_set1_epi8(3);
const __m128i m32s = _mm_set1_epi8(32); const __m128i m15 = _mm_set1_epi8(15);
const __m128i m2 = _mm_set1_epi8(2);
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -9119,12 +9117,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const uint8_t * restrict qh = x[i].qh; const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
// handle the q6_k -32 offset separately using bsums
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
__m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_0 = _mm_setzero_si128();
__m128i sumi_1 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128();
__m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); int is = 0;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
@ -9132,26 +9138,26 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
@ -9162,15 +9168,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
@ -9180,32 +9177,20 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
p16_0 = _mm_sub_epi16(p16_0, q8s_0); const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
p16_1 = _mm_sub_epi16(p16_1, q8s_1); const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
p16_2 = _mm_sub_epi16(p16_2, q8s_2); const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
p16_3 = _mm_sub_epi16(p16_3, q8s_3); const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
p16_4 = _mm_sub_epi16(p16_4, q8s_4); is += 4;
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
shuffle = _mm_add_epi8(shuffle, m2);
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
@ -9214,8 +9199,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
} }
__m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
} }
*s = hsum_float_8(acc); *s = hsum_float_8(acc);