mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 20:22:25 +01:00
ggml : add AVX2 implementation of quantize_row_q4_1 (#515)
* Add AVX2 implementation of quantize_row_q4_1 * Actually use AVX2 * Make quantize_row_q4_1 static Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
d0aaff571c
commit
2a98bc18ea
91
ggml.c
91
ggml.c
@ -688,7 +688,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
|
static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK;
|
||||||
|
|
||||||
@ -729,6 +729,93 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
|
||||||
|
assert(k % QK == 0);
|
||||||
|
|
||||||
|
#if defined(__AVX2__)
|
||||||
|
const int nb = k / QK;
|
||||||
|
|
||||||
|
block_q4_1 * restrict y = vy;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
// Load elements into 4 AVX vectors
|
||||||
|
__m256 v0 = _mm256_loadu_ps( x );
|
||||||
|
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
||||||
|
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
||||||
|
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||||
|
x += 32;
|
||||||
|
|
||||||
|
// Compute max for the block
|
||||||
|
__m256 vmax;
|
||||||
|
vmax = _mm256_max_ps( v0, v1 );
|
||||||
|
vmax = _mm256_max_ps( vmax, v2 );
|
||||||
|
vmax = _mm256_max_ps( vmax, v3 );
|
||||||
|
|
||||||
|
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
|
||||||
|
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||||
|
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||||
|
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||||
|
|
||||||
|
// Compute min for the block
|
||||||
|
__m256 vmin;
|
||||||
|
vmin = _mm256_min_ps( v0, v1 );
|
||||||
|
vmin = _mm256_min_ps( vmin, v2 );
|
||||||
|
vmin = _mm256_min_ps( vmin, v3 );
|
||||||
|
|
||||||
|
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
|
||||||
|
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
||||||
|
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
||||||
|
const float minScalar = _mm_cvtss_f32( min4 );
|
||||||
|
|
||||||
|
// Quantize these floats
|
||||||
|
const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
|
||||||
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
|
y[i].m = minScalar;
|
||||||
|
y[i].d = d;
|
||||||
|
|
||||||
|
// x = (x-min)*id
|
||||||
|
const __m256 mul = _mm256_set1_ps( id );
|
||||||
|
const __m256 off = _mm256_set1_ps( minScalar );
|
||||||
|
v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
|
||||||
|
v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
|
||||||
|
v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
|
||||||
|
v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
|
||||||
|
|
||||||
|
// Round to nearest integer
|
||||||
|
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||||
|
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||||
|
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||||
|
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||||
|
|
||||||
|
// Convert floats to integers
|
||||||
|
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||||
|
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||||
|
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||||
|
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||||
|
|
||||||
|
// Convert int32 to int16
|
||||||
|
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||||
|
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
||||||
|
// Convert int16 to int8
|
||||||
|
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
||||||
|
|
||||||
|
// We got our precious signed bytes, but the order is now wrong
|
||||||
|
// These AVX2 pack instructions process 16-byte pieces independently
|
||||||
|
// The following instruction is fixing the order
|
||||||
|
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||||
|
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||||
|
|
||||||
|
// Compress the vector into 4 bit/value, and store
|
||||||
|
__m128i res = packNibbles( i0 );
|
||||||
|
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// scalar
|
||||||
|
quantize_row_q4_1_reference(x, vy, k);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK;
|
||||||
@ -10135,7 +10222,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|||||||
for (int j = 0; j < n; j += k) {
|
for (int j = 0; j < n; j += k) {
|
||||||
block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
|
block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
|
||||||
|
|
||||||
quantize_row_q4_1(src + j, y, k);
|
quantize_row_q4_1_reference(src + j, y, k);
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK; l += 2) {
|
||||||
|
Loading…
Reference in New Issue
Block a user