ggml : AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring (#1099)

* AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring

* finish AVX vectorization of quantize_row_q8_0

* Rename hsum_int_8 to hsum_i32_8
This commit is contained in:
Stephan Walter 2023-04-22 07:37:05 +00:00 committed by GitHub
parent e9a9cb0c54
commit c5aa5e5777
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

213
ggml.c
View File

@ -450,6 +450,24 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
return bytes; return bytes;
} }
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = _mm256_extractf128_ps(x, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(x));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
return _mm_cvtss_f32(res);
}
// horizontally add 8 int32_t
static inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
#if __AVX2__ || __AVX512F__ #if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes // Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@ -470,6 +488,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
return bytes; return bytes;
} }
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1);
const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
return _mm256_cvtepi32_ps(summed_pairs);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(x, x);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(y, x);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
return sum_i16_pairs_float(dot);
}
static inline __m128i packNibbles( __m256i bytes ) static inline __m128i packNibbles( __m256i bytes )
{ {
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@ -1273,29 +1309,6 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
} }
} }
#ifdef __AVX2__
// There is no better way of doing this?
// I guess not, AVX is not very good at horizontal sums.
// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly
// faster than the solution below. As I don't have an AVX2 system handt right now to test,
// keeping the original.
// TODO: Please try and if it does make a differece, uncomment and remove the implementation below.
//static inline float horizontal_sum(__m256i a) {
// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
// __m256i sum = _mm256_add_epi32(a, b);
// __m256i hi = _mm256_unpackhi_epi64(sum, sum);
// sum = _mm256_add_epi32(sum, hi);
// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
//}
static inline float horizontal_sum(__m256i a) {
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
__m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
__m128i sum64 = _mm_add_epi32(hi64, sum128);
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
#endif
static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
assert(k % QK8_0 == 0); assert(k % QK8_0 == 0);
const int nb = k / QK8_0; const int nb = k / QK8_0;
@ -1384,9 +1397,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
__m256i i3 = _mm256_cvtps_epi32( v3 ); __m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__) #if defined(__AVX2__)
// Compute the sum of the quants and set y[i].s // Compute the sum of the quants and set y[i].s
y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
// Convert int32 to int16 // Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@ -1413,6 +1425,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
__m128i ni6 = _mm256_castsi256_si128( i3 ); __m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1); __m128i ni7 = _mm256_extractf128_si256( i3, 1);
// Compute the sum of the quants and set y[i].s
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0));
// Convert int32 to int16 // Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 ); ni0 = _mm_packs_epi32( ni0, ni1 );
ni2 = _mm_packs_epi32( ni2, ni3 ); ni2 = _mm_packs_epi32( ni2, ni3 );
@ -1430,14 +1447,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
// scalar // scalar
quantize_row_q8_0_reference(x, y, k); quantize_row_q8_0_reference(x, y, k);
#endif #endif
#if defined __AVX__
// TODO: vectorize this
for (int i=0; i<nb; ++i) {
int sum = 0;
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
y[i].s = y[i].d * sum;
}
#endif
} }
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@ -2374,8 +2383,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const block_q4_0 * restrict x = vx; const block_q4_0 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -2441,7 +2448,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
#endif #endif
} }
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
#elif defined(__AVX2__) #elif defined(__AVX2__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2459,32 +2466,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// Get absolute values of x vectors const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps( xy_q );
/* Multiply q with scale and accumulate */ /* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc ); acc = _mm256_fmadd_ps( d, q, acc );
} }
// Return horizontal sum of the acc vector *s = hsum_float_8(acc);
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#elif defined(__AVX__) #elif defined(__AVX__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2523,15 +2511,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
} }
// Return horizontal sum of the acc vector *s = hsum_float_8(acc);
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#else #else
// scalar // scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d0 = x[i].d; const float d0 = x[i].d;
const float d1 = y[i].d; const float d1 = y[i].d;
@ -2553,9 +2536,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
} }
sumf += d0*d1*sumi; sumf += d0*d1*sumi;
} }
#endif
*s = sumf; *s = sumf;
#endif
} }
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2567,8 +2549,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const block_q4_1 * restrict x = vx; const block_q4_1 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
float sumf = 0.0;
// TODO: add AVX / WASM SIMD / etc // TODO: add AVX / WASM SIMD / etc
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
@ -2635,7 +2615,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
#endif #endif
} }
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__) #elif defined(__AVX2__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2646,7 +2626,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d; const float * d0 = &x[i].d;
const float * d1 = &y[i].d; const float * d1 = &y[i].d;
//const float * m0 = &x[i].m;
summs += x[i].m * y[i].s; summs += x[i].m * y[i].s;
@ -2660,33 +2639,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
// Get absolute values of x vectors const __m256 xy = mul_sum_i8_pairs_float(bx, by);
const __m256i ax = _mm256_sign_epi8( bx, bx );
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8( by, bx );
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
const __m256i ones = _mm256_set1_epi16( 1 );
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
// Convert to vector of 8 int32_t to 8 floats
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
// Accumulate d0*d1*x*y // Accumulate d0*d1*x*y
acc = _mm256_fmadd_ps( d0d1, xy, acc ); acc = _mm256_fmadd_ps( d0d1, xy, acc );
} }
// Return horizontal sum of the acc vector *s = hsum_float_8(acc) + summs;
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res ) + summs;
#else #else
// scalar // scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d0 = x[i].d; const float d0 = x[i].d;
const float m0 = x[i].m; const float m0 = x[i].m;
@ -2708,9 +2670,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
sumf += f0*f2 + f1*f3; sumf += f0*f2 + f1*f3;
} }
} }
#endif
*s = sumf; *s = sumf;
#endif
} }
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2723,8 +2684,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
const block_q4_2 * restrict x = vx; const block_q4_2 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -2802,7 +2761,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
#endif #endif
} }
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__) #elif defined(__AVX2__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2824,32 +2783,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// Get absolute values of x vectors const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps(xy_q);
/* Multiply q with scale and accumulate */ /* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps(d, q, acc); acc = _mm256_fmadd_ps(d, q, acc);
} }
// Return horizontal sum of the acc vector *s = hsum_float_8(acc);
__m128 res = _mm256_extractf128_ps(acc, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));
sumf = _mm_cvtss_f32(res);
#else #else
// scalar // scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs; const uint8_t * restrict x1 = x[2*i + 1].qs;
@ -2884,9 +2827,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
sumf += (d0 * y[i].d) * sumi_0; sumf += (d0 * y[i].d) * sumi_0;
sumf += (d1 * y[i].d) * sumi_1; sumf += (d1 * y[i].d) * sumi_1;
} }
#endif
*s = sumf; *s = sumf;
#endif
} }
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -2899,8 +2841,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const block_q4_3 * restrict x = vx; const block_q4_3 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
float sumf = 0.0;
#if defined(__ARM_NEON) #if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f);
@ -2986,9 +2926,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif #endif
} }
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 dx = _mm256_set_m128(d1, d0);
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
const __m256 mx = _mm256_set_m128(m1, m0);
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
const __m256i bx = _mm256_set_m128i(bx1, bx0);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
const __m256 syf = sum_i16_pairs_float(syi);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
acc = _mm256_fmadd_ps(sxy, dy, acc);
}
*s = hsum_float_8(acc);
#else #else
// scalar // scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs; const uint8_t * restrict x1 = x[2*i + 1].qs;
@ -3031,9 +3003,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
sumf += (d0*sxy_0 + m0*sy_0)*y[i].d; sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d; sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
} }
#endif
*s = sumf; *s = sumf;
#endif
} }