ggml : add AVX support based on AVX2 code (#1430)

This commit is contained in:
katsu560 2023-05-14 19:03:51 +09:00 committed by GitHub
parent 601a033475
commit 60f8c361ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

135
ggml.c
View File

@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
return _mm_packus_epi16( r0, r1 ); return _mm_packus_epi16( r0, r1 );
#endif #endif
} }
#else #elif defined(__AVX__)
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
memcpy(&x32, x, sizeof(uint32_t));
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
bytesl = _mm_or_si128(bytesl, bit_mask);
bytesh = _mm_or_si128(bytesh, bit_mask);
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
return _mm256_set_m128i(bytesh, bytesl);
}
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
__m128i tmph = _mm_srli_epi16(tmpl, 4);
const __m128i lowMask = _mm_set1_epi8(0xF);
tmpl = _mm_and_si128(lowMask, tmpl);
tmph = _mm_and_si128(lowMask, tmph);
return _mm256_set_m128i(tmph, tmpl);
}
// add int16_t pairwise and return as float vector
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
const __m128i ones = _mm_set1_epi16(1);
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
return _mm256_cvtepi32_ps(summed_pairs);
}
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m128i xl = _mm256_castsi256_si128(x);
const __m128i xh = _mm256_extractf128_si256(x, 1);
const __m128i yl = _mm256_castsi256_si128(y);
const __m128i yh = _mm256_extractf128_si256(y, 1);
// Get absolute values of x vectors
const __m128i axl = _mm_sign_epi8(xl, xl);
const __m128i axh = _mm_sign_epi8(xh, xh);
// Sign the values of the y vectors
const __m128i syl = _mm_sign_epi8(yl, xl);
const __m128i syh = _mm_sign_epi8(yh, xh);
// Perform multiplication and create 16-bit values
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
const __m128i doth = _mm_maddubs_epi16(axh, syh);
return sum_i16_pairs_float(doth, dotl);
}
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{ {
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
} }
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__) #elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const __m256 xy = mul_sum_i8_pairs_float(bx, by); const __m256 xy = mul_sum_i8_pairs_float(bx, by);
// Accumulate d0*d1*x*y // Accumulate d0*d1*x*y
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d0d1, xy, acc ); acc = _mm256_fmadd_ps( d0d1, xy, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
#endif
} }
*s = hsum_float_8(acc) + summs; *s = hsum_float_8(acc) + summs;
@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(d, q, acc); acc = _mm256_fmadd_ps(d, q, acc);
} }
*s = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8((char)0xF0);
// Main loop
for (int i = 0; i < nb; i++) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, mask);
bxhih = _mm_andnot_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
/* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
}
*s = hsum_float_8(acc); *s = hsum_float_8(acc);
#else #else
// scalar // scalar
@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
} }
*s = hsum_float_8(acc) + summs;
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m128i mask = _mm_set1_epi8(0x10);
float summs = 0.0f;
// Main loop
for (int i = 0; i < nb; i++) {
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, mask);
bxhih = _mm_and_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx);
__m128i bxh = _mm256_extractf128_si256(bx, 1);
bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih);
bx = _mm256_set_m128i(bxh, bxl);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
}
*s = hsum_float_8(acc) + summs; *s = hsum_float_8(acc) + summs;
#else #else
// scalar // scalar
@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
} }
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__) #elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros // Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps(); __m256 acc = _mm256_setzero_ps();
@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
const __m256 q = mul_sum_i8_pairs_float(bx, by); const __m256 q = mul_sum_i8_pairs_float(bx, by);
// Multiply q with scale and accumulate // Multiply q with scale and accumulate
#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d, q, acc ); acc = _mm256_fmadd_ps( d, q, acc );
#else
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
#endif
} }
*s = hsum_float_8(acc); *s = hsum_float_8(acc);