mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 04:47:04 +01:00
iq2_xs: attempt to fix AVX dot product for QK_K = 64
Tests pass, but I get gibberish.
This commit is contained in:
parent
13ba37f1aa
commit
28e6146c11
@ -9492,15 +9492,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
||||
|
||||
#elif defined(__AVX2__)
|
||||
|
||||
const __m128i m4 = _mm_set1_epi8(0xf);
|
||||
const __m128i m1 = _mm_set1_epi8(1);
|
||||
const __m256i m511 = _mm256_set1_epi16(511);
|
||||
const __m256i mone = _mm256_set1_epi8(1);
|
||||
|
||||
static const uint8_t k_bit_helper[32] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
static const char block_sign_shuffle_mask_1[32] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
|
||||
@ -9514,11 +9506,77 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
||||
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
|
||||
};
|
||||
|
||||
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
||||
const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
|
||||
const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
|
||||
const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
|
||||
|
||||
#if QK_K == 64
|
||||
static const uint8_t k_bit_helper[16] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper);
|
||||
const __m128i m511 = _mm_set1_epi16(511);
|
||||
typedef union {
|
||||
__m128i vec_index;
|
||||
uint16_t index[8];
|
||||
} index_t;
|
||||
|
||||
index_t idx;
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs);
|
||||
idx.vec_index = _mm_and_si128(q2_data, m511);
|
||||
|
||||
const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9);
|
||||
const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13);
|
||||
const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper);
|
||||
|
||||
const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
|
||||
const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits);
|
||||
const __m256i full_signs = _mm256_set_m128i(full_sign_bits, full_sign_bits);
|
||||
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32));
|
||||
|
||||
const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]],
|
||||
iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]);
|
||||
const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]],
|
||||
iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]);
|
||||
|
||||
__m256i signs;
|
||||
signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
|
||||
|
||||
signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2);
|
||||
signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
|
||||
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
|
||||
|
||||
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
|
||||
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
|
||||
|
||||
const __m256i sc1 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1));
|
||||
const __m256i sc2 = _mm256_set_m128i(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1));
|
||||
|
||||
const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2));
|
||||
|
||||
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf);
|
||||
|
||||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
#else
|
||||
|
||||
static const uint8_t k_bit_helper[32] = {
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
|
||||
};
|
||||
const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
|
||||
const __m256i m511 = _mm256_set1_epi16(511);
|
||||
const __m128i m4 = _mm_set1_epi8(0xf);
|
||||
const __m128i m1 = _mm_set1_epi8(1);
|
||||
|
||||
uint64_t aux64;
|
||||
|
||||
// somewhat hacky, but gives a significant boost in performance
|
||||
@ -9607,6 +9665,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
|
||||
}
|
||||
|
||||
*s = 0.125f * hsum_float_8(accumf);
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user