mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
Adding SSE instructions to ggml_vec_dot_q4_0_q8_0 (#1413)
This commit is contained in:
parent
0cd22e190a
commit
ac0cd259d5
135
ggml.c
135
ggml.c
@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|||||||
// quantization
|
// quantization
|
||||||
//
|
//
|
||||||
|
|
||||||
#if __AVX__ || __AVX2__ || __AVX512F__
|
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
||||||
// multiply int8_t, add results pairwise twice
|
// multiply int8_t, add results pairwise twice
|
||||||
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
||||||
// Get absolute values of x vectors
|
// Get absolute values of x vectors
|
||||||
@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
|
|||||||
return _mm_madd_epi16(ones, dot);
|
return _mm_madd_epi16(ones, dot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __AVX__ || __AVX2__ || __AVX512F__
|
||||||
// horizontally add 8 floats
|
// horizontally add 8 floats
|
||||||
static inline float hsum_float_8(const __m256 x) {
|
static inline float hsum_float_8(const __m256 x) {
|
||||||
__m128 res = _mm256_extractf128_ps(x, 1);
|
__m128 res = _mm256_extractf128_ps(x, 1);
|
||||||
@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
|||||||
return _mm_packus_epi16( bytes1, bytes2);
|
return _mm_packus_epi16( bytes1, bytes2);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#elif defined(__SSSE3__)
|
||||||
|
// horizontally add 4x4 floats
|
||||||
|
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
|
||||||
|
__m128 res_0 =_mm_hadd_ps(a, b);
|
||||||
|
__m128 res_1 =_mm_hadd_ps(c, d);
|
||||||
|
__m128 res =_mm_hadd_ps(res_0, res_1);
|
||||||
|
res =_mm_hadd_ps(res, res);
|
||||||
|
res =_mm_hadd_ps(res, res);
|
||||||
|
|
||||||
|
return _mm_cvtss_f32(res);
|
||||||
|
}
|
||||||
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
||||||
|
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
|
||||||
|
|
||||||
#if __ARM_NEON
|
#if __ARM_NEON
|
||||||
|
|
||||||
@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|||||||
}
|
}
|
||||||
|
|
||||||
*s = hsum_float_8(acc);
|
*s = hsum_float_8(acc);
|
||||||
|
#elif defined(__SSSE3__)
|
||||||
|
// set constants
|
||||||
|
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||||
|
const __m128i off = _mm_set1_epi8(8);
|
||||||
|
|
||||||
|
// Initialize accumulator with zeros
|
||||||
|
__m128 acc_0 = _mm_setzero_ps();
|
||||||
|
__m128 acc_1 = _mm_setzero_ps();
|
||||||
|
__m128 acc_2 = _mm_setzero_ps();
|
||||||
|
__m128 acc_3 = _mm_setzero_ps();
|
||||||
|
|
||||||
|
// First round without accumulation
|
||||||
|
{
|
||||||
|
_mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||||
|
_mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||||
|
|
||||||
|
// Compute combined scale for the block 0 and 1
|
||||||
|
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
|
||||||
|
|
||||||
|
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
|
||||||
|
|
||||||
|
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
|
||||||
|
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
|
||||||
|
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||||
|
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||||
|
|
||||||
|
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
|
||||||
|
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
|
||||||
|
bx_1 = _mm_sub_epi8(bx_1, off);
|
||||||
|
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
|
||||||
|
|
||||||
|
_mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||||
|
_mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||||
|
|
||||||
|
// Compute combined scale for the block 2 and 3
|
||||||
|
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
|
||||||
|
|
||||||
|
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
|
||||||
|
|
||||||
|
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
|
||||||
|
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
|
||||||
|
bx_2 = _mm_sub_epi8(bx_2, off);
|
||||||
|
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
|
||||||
|
|
||||||
|
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
|
||||||
|
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
|
||||||
|
bx_3 = _mm_sub_epi8(bx_3, off);
|
||||||
|
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
|
||||||
|
|
||||||
|
// Convert int32_t to float
|
||||||
|
__m128 p0 = _mm_cvtepi32_ps(i32_0);
|
||||||
|
__m128 p1 = _mm_cvtepi32_ps(i32_1);
|
||||||
|
__m128 p2 = _mm_cvtepi32_ps(i32_2);
|
||||||
|
__m128 p3 = _mm_cvtepi32_ps(i32_3);
|
||||||
|
|
||||||
|
// Apply the scale
|
||||||
|
acc_0 = _mm_mul_ps( d_0_1, p0 );
|
||||||
|
acc_1 = _mm_mul_ps( d_0_1, p1 );
|
||||||
|
acc_2 = _mm_mul_ps( d_2_3, p2 );
|
||||||
|
acc_3 = _mm_mul_ps( d_2_3, p3 );
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
for (int i = 2; i < nb; i+=2) {
|
||||||
|
_mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
|
||||||
|
_mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
|
||||||
|
|
||||||
|
// Compute combined scale for the block 0 and 1
|
||||||
|
const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
|
||||||
|
|
||||||
|
const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||||
|
|
||||||
|
__m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
|
||||||
|
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||||
|
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||||
|
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||||
|
|
||||||
|
__m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
|
||||||
|
__m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||||
|
bx_1 = _mm_sub_epi8(bx_1, off);
|
||||||
|
const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
|
||||||
|
|
||||||
|
_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
|
||||||
|
_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
|
||||||
|
|
||||||
|
// Compute combined scale for the block 2 and 3
|
||||||
|
const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
|
||||||
|
|
||||||
|
const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
|
||||||
|
|
||||||
|
__m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
|
||||||
|
__m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
|
||||||
|
bx_2 = _mm_sub_epi8(bx_2, off);
|
||||||
|
const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
|
||||||
|
|
||||||
|
__m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
|
||||||
|
__m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
|
||||||
|
bx_3 = _mm_sub_epi8(bx_3, off);
|
||||||
|
const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
|
||||||
|
|
||||||
|
// Convert int32_t to float
|
||||||
|
__m128 p0 = _mm_cvtepi32_ps(i32_0);
|
||||||
|
__m128 p1 = _mm_cvtepi32_ps(i32_1);
|
||||||
|
__m128 p2 = _mm_cvtepi32_ps(i32_2);
|
||||||
|
__m128 p3 = _mm_cvtepi32_ps(i32_3);
|
||||||
|
|
||||||
|
// Apply the scale
|
||||||
|
__m128 p0_d = _mm_mul_ps( d_0_1, p0 );
|
||||||
|
__m128 p1_d = _mm_mul_ps( d_0_1, p1 );
|
||||||
|
__m128 p2_d = _mm_mul_ps( d_2_3, p2 );
|
||||||
|
__m128 p3_d = _mm_mul_ps( d_2_3, p3 );
|
||||||
|
|
||||||
|
// Acummulate
|
||||||
|
acc_0 = _mm_add_ps(p0_d, acc_0);
|
||||||
|
acc_1 = _mm_add_ps(p1_d, acc_1);
|
||||||
|
acc_2 = _mm_add_ps(p2_d, acc_2);
|
||||||
|
acc_3 = _mm_add_ps(p3_d, acc_3);
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
|
||||||
#else
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
float sumf = 0.0;
|
float sumf = 0.0;
|
||||||
|
Loading…
Reference in New Issue
Block a user