mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
Kahan summation on Q4_1
This commit is contained in:
parent
69071d3b6b
commit
66ea164e1d
30
ggml.c
30
ggml.c
@ -1704,6 +1704,9 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
|||||||
// Accumulator for constant offsets
|
// Accumulator for constant offsets
|
||||||
__m128 acc_offset = _mm_setzero_ps(); //0.0f;
|
__m128 acc_offset = _mm_setzero_ps(); //0.0f;
|
||||||
|
|
||||||
|
__m256 acc_err = _mm256_setzero_ps();
|
||||||
|
__m128 acc_offset_err = _mm_setzero_ps();
|
||||||
|
|
||||||
// Main loop
|
// Main loop
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
const float * m0 = (const float *) (pm0 + i*bs);
|
const float * m0 = (const float *) (pm0 + i*bs);
|
||||||
@ -1756,17 +1759,30 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
|||||||
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
||||||
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
||||||
|
|
||||||
// Apply the scales, and accumulate
|
|
||||||
// acc += d0*m1*x + d1*m0*y
|
|
||||||
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
|
||||||
|
|
||||||
// Convert int32_t to float
|
// Convert int32_t to float
|
||||||
__m256 p = _mm256_cvtepi32_ps( i32 );
|
__m256 p = _mm256_cvtepi32_ps( i32 );
|
||||||
// acc += d0*d1*x*y
|
|
||||||
acc = _mm256_fmadd_ps( scale_01, p, acc );
|
// Apply the scales, and accumulate
|
||||||
|
// Use Kahan error compensation
|
||||||
|
// acc += d0*m1*x + d1*m0*y + d0*d1*x*y
|
||||||
|
__m256 delta = _mm256_mul_ps( scale_01, p );
|
||||||
|
delta = _mm256_fmadd_ps( cross_scales, sums, delta );
|
||||||
|
delta = _mm256_sub_ps( delta, acc_err );
|
||||||
|
|
||||||
|
__m256 acc_next = _mm256_add_ps( acc, delta );
|
||||||
|
acc_err = _mm256_sub_ps( _mm256_sub_ps( acc_next, acc ), delta );
|
||||||
|
|
||||||
|
acc = acc_next;
|
||||||
|
|
||||||
|
__m128 offs_delta = _mm_mul_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ) );
|
||||||
|
offs_delta = _mm_sub_ss( offs_delta, acc_offset_err );
|
||||||
|
|
||||||
|
__m128 offs_next = _mm_add_ss( acc_offset, offs_delta );
|
||||||
|
acc_offset_err = _mm_sub_ss( _mm_sub_ss( offs_next, acc_offset ), offs_delta );
|
||||||
|
acc_offset = offs_next;
|
||||||
|
|
||||||
// acc_offset += m0*m1 (avoid reloading from RAM)
|
// acc_offset += m0*m1 (avoid reloading from RAM)
|
||||||
acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset );
|
//acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset );
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return horizontal sum of the acc vector
|
// Return horizontal sum of the acc vector
|
||||||
|
Loading…
x
Reference in New Issue
Block a user