mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 06:40:45 +01:00
iq1_s: WIP AVX2 dot product - something is not right
This commit is contained in:
parent
d94139bf27
commit
592b3b26bb
@ -9282,6 +9282,14 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||
return _mm256_maddubs_epi16(ax, sy);
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
assert(n % QK_K == 0);
|
||||
|
||||
@ -9290,6 +9298,59 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined __AVX2__
|
||||
|
||||
const __m128i m8 = _mm_set1_epi8(0x08);
|
||||
const __m128i m7 = _mm_set1_epi8(0x07);
|
||||
const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
|
||||
const __m128i shuffle_s[4] = {
|
||||
_mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
|
||||
_mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
|
||||
_mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
|
||||
_mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
|
||||
};
|
||||
|
||||
uint64_t aux64;
|
||||
|
||||
__m256i v_gindex;
|
||||
const uint16_t * gindex = (const uint16_t *)&v_gindex;
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint8_t * sc = x[i].scales;
|
||||
|
||||
__m256i sumi = _mm256_setzero_si256();
|
||||
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
||||
const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
|
||||
memcpy(&aux64, sc, 8); sc += 8;
|
||||
const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
|
||||
const __m256i hbit = _mm256_cvtepi8_epi16(_mm_and_si128(qh, m8));
|
||||
v_gindex = _mm256_or_si256(_mm256_cvtepi8_epi16(ql), _mm256_slli_epi16(hbit, 5));
|
||||
const __m128i scales = _mm_and_si128(qh, m7);
|
||||
|
||||
for (int i32 = 0; i32 < 4; ++i32) {
|
||||
const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
|
||||
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
|
||||
iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
|
||||
const __m256i dot = mul_add_epi8(q1b, q8b);
|
||||
const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
|
||||
const __m256i p = _mm256_madd_epi16(s16, dot);
|
||||
sumi = _mm256_add_epi32(sumi, p);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
|
||||
|
||||
}
|
||||
|
||||
*s = hsum_float_8(accum);
|
||||
|
||||
#else
|
||||
|
||||
int db[4];
|
||||
uint16_t idx[4];
|
||||
|
||||
@ -9326,6 +9387,8 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
|
||||
|
||||
*s = sumf;
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
// ================================ IQ2 quantization =============================================
|
||||
|
Loading…
Reference in New Issue
Block a user