mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 12:21:40 +01:00
Adding SSE instructions to ggml_vec_dot_q4_0_q8_0 (#1413)
This commit is contained in:
parent
0cd22e190a
commit
ac0cd259d5
135
ggml.c
135
ggml.c
@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
||||
// quantization
|
||||
//
|
||||
|
||||
#if __AVX__ || __AVX2__ || __AVX512F__
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
||||
// multiply int8_t, add results pairwise twice
|
||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||
// Get absolute values of x vectors
|
||||
@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||
return _mm_madd_epi16(ones, dot);
|
||||
}
|
||||
|
||||
#if __AVX__ || __AVX2__ || __AVX512F__
|
||||
// horizontally add 8 floats
|
||||
static inline float hsum_float_8(const __m256 x) {
|
||||
__m128 res = _mm256_extractf128_ps(x, 1);
|
||||
@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||
return _mm_packus_epi16( bytes1, bytes2);
|
||||
}
|
||||
#endif
|
||||
#elif defined(__SSSE3__)
|
||||
// horizontally add 4x4 floats
|
||||
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
|
||||
__m128 res_0 =_mm_hadd_ps(a, b);
|
||||
__m128 res_1 =_mm_hadd_ps(c, d);
|
||||
__m128 res =_mm_hadd_ps(res_0, res_1);
|
||||
res =_mm_hadd_ps(res, res);
|
||||
res =_mm_hadd_ps(res, res);
|
||||
|
||||
return _mm_cvtss_f32(res);
|
||||
}
|
||||
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
||||
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
||||
|
||||
#if __ARM_NEON
|
||||
|
||||
@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#elif defined(__SSSE3__)
|
||||
// set constants
|
||||
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||
const __m128i off = _mm_set1_epi8(8);
|
||||
|
||||
// Initialize accumulator with zeros
|
||||
__m128 acc_0 = _mm_setzero_ps();
|
||||
__m128 acc_1 = _mm_setzero_ps();
|
||||
__m128 acc_2 = _mm_setzero_ps();
|
||||
__m128 acc_3 = _mm_setzero_ps();
|
||||
|
||||
// First round without accumulation
|
||||
{
|
||||
_mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 0 and 1
|
||||
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
|
||||
|
||||
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
|
||||
|
||||
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
|
||||
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
|
||||
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
|
||||
bx_1 = _mm_sub_epi8(bx_1, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
|
||||
|
||||
_mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 2 and 3
|
||||
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
|
||||
|
||||
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
|
||||
|
||||
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
|
||||
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
|
||||
bx_2 = _mm_sub_epi8(bx_2, off);
|
||||
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
|
||||
|
||||
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
|
||||
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
|
||||
bx_3 = _mm_sub_epi8(bx_3, off);
|
||||
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
|
||||
|
||||
// Convert int32_t to float
|
||||
__m128 p0 = _mm_cvtepi32_ps(i32_0);
|
||||
__m128 p1 = _mm_cvtepi32_ps(i32_1);
|
||||
__m128 p2 = _mm_cvtepi32_ps(i32_2);
|
||||
__m128 p3 = _mm_cvtepi32_ps(i32_3);
|
||||
|
||||
// Apply the scale
|
||||
acc_0 = _mm_mul_ps( d_0_1, p0 );
|
||||
acc_1 = _mm_mul_ps( d_0_1, p1 );
|
||||
acc_2 = _mm_mul_ps( d_2_3, p2 );
|
||||
acc_3 = _mm_mul_ps( d_2_3, p3 );
|
||||
}
|
||||
|
||||
// Main loop
|
||||
for (int i = 2; i < nb; i+=2) {
|
||||
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 0 and 1
|
||||
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
|
||||
|
||||
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
|
||||
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
|
||||
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
|
||||
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx_1 = _mm_sub_epi8(bx_1, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
|
||||
|
||||
_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
|
||||
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
|
||||
|
||||
// Compute combined scale for the block 2 and 3
|
||||
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
|
||||
|
||||
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
|
||||
|
||||
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
|
||||
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
|
||||
bx_2 = _mm_sub_epi8(bx_2, off);
|
||||
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
|
||||
|
||||
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
|
||||
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
|
||||
bx_3 = _mm_sub_epi8(bx_3, off);
|
||||
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
|
||||
|
||||
// Convert int32_t to float
|
||||
__m128 p0 = _mm_cvtepi32_ps(i32_0);
|
||||
__m128 p1 = _mm_cvtepi32_ps(i32_1);
|
||||
__m128 p2 = _mm_cvtepi32_ps(i32_2);
|
||||
__m128 p3 = _mm_cvtepi32_ps(i32_3);
|
||||
|
||||
// Apply the scale
|
||||
__m128 p0_d = _mm_mul_ps( d_0_1, p0 );
|
||||
__m128 p1_d = _mm_mul_ps( d_0_1, p1 );
|
||||
__m128 p2_d = _mm_mul_ps( d_2_3, p2 );
|
||||
__m128 p3_d = _mm_mul_ps( d_2_3, p3 );
|
||||
|
||||
// Acummulate
|
||||
acc_0 = _mm_add_ps(p0_d, acc_0);
|
||||
acc_1 = _mm_add_ps(p1_d, acc_1);
|
||||
acc_2 = _mm_add_ps(p2_d, acc_2);
|
||||
acc_3 = _mm_add_ps(p3_d, acc_3);
|
||||
}
|
||||
|
||||
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
|
||||
#else
|
||||
// scalar
|
||||
float sumf = 0.0;
|
||||
|
Loading…
Reference in New Issue
Block a user