mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 14:20:31 +01:00
Optimize AVX2 ggml_vec_dot_q4_0 (#642)
This commit is contained in:
parent
02c5b27e91
commit
1d08882afa
31
ggml.c
31
ggml.c
@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||||||
const block_q4_0 * restrict x = vx;
|
const block_q4_0 * restrict x = vx;
|
||||||
const block_q4_0 * restrict y = vy;
|
const block_q4_0 * restrict y = vy;
|
||||||
|
|
||||||
ggml_float sumf = 0.0;
|
float sumf = 0.0;
|
||||||
|
|
||||||
#if defined(__ARM_NEON)
|
#if defined(__ARM_NEON)
|
||||||
float sum0 = 0.0f;
|
float sum0 = 0.0f;
|
||||||
@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf = (ggml_float)(sum0 + sum1);
|
sumf = sum0 + sum1;
|
||||||
#elif defined(__AVX512F__)
|
#elif defined(__AVX512F__)
|
||||||
// Initialize accumulator with zeros
|
// Initialize accumulator with zeros
|
||||||
__m512 acc0 = _mm512_setzero_ps();
|
__m512 acc0 = _mm512_setzero_ps();
|
||||||
@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||||||
__m256 acc = _mm256_setzero_ps();
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
// Main loop
|
// Main loop
|
||||||
|
// TODO: figure a way to do this in a portable way
|
||||||
|
#ifdef __GNUC__
|
||||||
|
#pragma GCC unroll 16
|
||||||
|
#endif
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
// Compute combined scale for the block
|
// Compute combined scale for the block
|
||||||
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
||||||
@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||||||
bx = _mm256_sub_epi8( bx, off );
|
bx = _mm256_sub_epi8( bx, off );
|
||||||
by = _mm256_sub_epi8( by, off );
|
by = _mm256_sub_epi8( by, off );
|
||||||
|
|
||||||
// Sign-extend first 16 signed bytes into int16_t
|
// Get absolute values of x vectors
|
||||||
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
|
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
||||||
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
|
||||||
// Compute products of int16_t integers, add pairwise
|
|
||||||
__m256i i32 = _mm256_madd_epi16( x16, y16 );
|
|
||||||
|
|
||||||
// Sign-extend last 16 signed bytes into int16_t vectors
|
// Sign the values of the y vectors
|
||||||
x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
|
const __m256i sy = _mm256_sign_epi8(by, bx);
|
||||||
y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
|
||||||
// Accumulate products of int16_t integers
|
// Perform multiplication and create 16-bit values
|
||||||
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
|
|
||||||
|
const __m256i ones = _mm256_set1_epi16(1);
|
||||||
|
const __m256i i32 = _mm256_madd_epi16(ones, dot);
|
||||||
|
|
||||||
// Convert int32_t to float
|
// Convert int32_t to float
|
||||||
__m256 p = _mm256_cvtepi32_ps( i32 );
|
const __m256 p = _mm256_cvtepi32_ps( i32 );
|
||||||
|
|
||||||
// Apply the scale, and accumulate
|
// Apply the scale, and accumulate
|
||||||
acc = _mm256_fmadd_ps( d, p, acc );
|
acc = _mm256_fmadd_ps( d, p, acc );
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user