mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 14:50:51 +01:00
Update quantize_row_q4_0 for AVX/AVX2
This commit is contained in:
parent
3698f79e6a
commit
5d5f2b2efa
63
ggml.c
63
ggml.c
@ -794,22 +794,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||||||
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||||
x += 32;
|
x += 32;
|
||||||
|
|
||||||
// Compute max(abs(e)) for the block
|
// Compute max for the block
|
||||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
__m256 max = _mm256_max_ps( v0, v1 );
|
||||||
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
max = _mm256_max_ps( max, maxTmp );
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
|
||||||
|
|
||||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
||||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||||
|
|
||||||
|
// Compute min for the block
|
||||||
|
__m256 min = _mm256_min_ps( v0, v1 );
|
||||||
|
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
||||||
|
min = _mm256_min_ps( min, minTmp );
|
||||||
|
|
||||||
|
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
||||||
|
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
||||||
|
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
||||||
|
const float minScalar = _mm_cvtss_f32( min4 );
|
||||||
|
|
||||||
// Quantize these floats
|
// Quantize these floats
|
||||||
const float d = maxScalar / 7.0f;
|
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
||||||
|
const float d = magnitude / -8.0f;
|
||||||
y[i].d = d;
|
y[i].d = d;
|
||||||
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
|
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
||||||
const __m256 mul = _mm256_set1_ps( id );
|
const __m256 mul = _mm256_set1_ps( id );
|
||||||
|
|
||||||
// Apply the multiplier
|
// Apply the multiplier
|
||||||
@ -842,9 +851,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||||||
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||||
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||||
|
|
||||||
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
|
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
||||||
const __m256i off = _mm256_set1_epi8( 8 );
|
const __m256i off = _mm256_set1_epi8( 8 );
|
||||||
i0 = _mm256_add_epi8( i0, off );
|
i0 = _mm256_add_epi8( i0, off );
|
||||||
|
const __m256i maxNibble = _mm256_set1_epi8( 15 );
|
||||||
|
i0 = _mm256_min_epi8( i0, maxNibble );
|
||||||
|
|
||||||
// Compress the vector into 4 bit/value, and store
|
// Compress the vector into 4 bit/value, and store
|
||||||
__m128i res = packNibbles( i0 );
|
__m128i res = packNibbles( i0 );
|
||||||
@ -859,22 +870,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||||||
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||||
x += 32;
|
x += 32;
|
||||||
|
|
||||||
// Compute max(abs(e)) for the block
|
// Compute max for the block
|
||||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
__m256 max = _mm256_max_ps( v0, v1 );
|
||||||
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
|
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
|
max = _mm256_max_ps( max, maxTmp );
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
|
|
||||||
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
|
|
||||||
|
|
||||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
||||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||||
|
|
||||||
|
// Compute min for the block
|
||||||
|
__m256 min = _mm256_min_ps( v0, v1 );
|
||||||
|
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
||||||
|
min = _mm256_min_ps( min, minTmp );
|
||||||
|
|
||||||
|
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
||||||
|
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
||||||
|
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
||||||
|
const float minScalar = _mm_cvtss_f32( min4 );
|
||||||
|
|
||||||
// Quantize these floats
|
// Quantize these floats
|
||||||
const float d = maxScalar / 7.0f;
|
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
||||||
|
const float d = magnitude / -8.0f;
|
||||||
y[i].d = d;
|
y[i].d = d;
|
||||||
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
|
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
||||||
const __m256 mul = _mm256_set1_ps( id );
|
const __m256 mul = _mm256_set1_ps( id );
|
||||||
|
|
||||||
// Apply the multiplier
|
// Apply the multiplier
|
||||||
@ -915,10 +935,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||||||
ni0 = _mm_packs_epi16( ni0, ni2 );
|
ni0 = _mm_packs_epi16( ni0, ni2 );
|
||||||
ni4 = _mm_packs_epi16( ni4, ni6 );
|
ni4 = _mm_packs_epi16( ni4, ni6 );
|
||||||
|
|
||||||
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
|
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
||||||
const __m128i off = _mm_set1_epi8( 8 );
|
const __m128i off = _mm_set1_epi8( 8 );
|
||||||
ni0 = _mm_add_epi8( ni0, off );
|
ni0 = _mm_add_epi8( ni0, off );
|
||||||
ni4 = _mm_add_epi8( ni4, off );
|
ni4 = _mm_add_epi8( ni4, off );
|
||||||
|
const __m128i maxNibble = _mm_set1_epi8( 15 );
|
||||||
|
ni0 = _mm_min_epi8( ni0, maxNibble );
|
||||||
|
ni4 = _mm_min_epi8( ni4, maxNibble );
|
||||||
|
|
||||||
// Compress the vector into 4 bit/value, and store
|
// Compress the vector into 4 bit/value, and store
|
||||||
__m128i res = packNibbles( ni0, ni4 );
|
__m128i res = packNibbles( ni0, ni4 );
|
||||||
|
Loading…
Reference in New Issue
Block a user