mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
AVX2 optimization for vec_dot_q4_2_q8_0 (#1068)
This commit is contained in:
parent
02d6988121
commit
c8c2c52482
99
ggml.c
99
ggml.c
@ -467,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
||||
// quantization
|
||||
//
|
||||
|
||||
// AVX routines provided by GH user Const-me
|
||||
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
||||
#if __AVX__ || __AVX2__ || __AVX512F__
|
||||
// Unpack 16 4-bit fields into 16 bytes
|
||||
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
||||
{
|
||||
// Load 8 bytes from memory
|
||||
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
|
||||
|
||||
// Expand bytes into uint16_t values
|
||||
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
||||
|
||||
// Unpack values into individual bytes
|
||||
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
||||
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
||||
__m128i low = _mm_and_si128( lowMask, bytes );
|
||||
high = _mm_slli_epi16( high, 4 );
|
||||
bytes = _mm_or_si128( low, high );
|
||||
return bytes;
|
||||
}
|
||||
|
||||
#if __AVX2__ || __AVX512F__
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
||||
{
|
||||
// Load 16 bytes from memory
|
||||
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
|
||||
@ -503,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
|
||||
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
||||
return _mm_packus_epi16( r0, r1 );
|
||||
}
|
||||
#elif __AVX__
|
||||
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
|
||||
{
|
||||
// Load 8 bytes from memory
|
||||
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
|
||||
|
||||
// Expand bytes into uint16_t values
|
||||
__m128i bytes = _mm_cvtepu8_epi16( tmp );
|
||||
|
||||
// Unpack values into individual bytes
|
||||
const __m128i lowMask = _mm_set1_epi8( 0xF );
|
||||
__m128i high = _mm_andnot_si128( lowMask, bytes );
|
||||
__m128i low = _mm_and_si128( lowMask, bytes );
|
||||
high = _mm_slli_epi16( high, 4 );
|
||||
bytes = _mm_or_si128( low, high );
|
||||
return bytes;
|
||||
}
|
||||
|
||||
#else
|
||||
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||
{
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
@ -537,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||
return _mm_packus_epi16( bytes1, bytes2);
|
||||
}
|
||||
#endif
|
||||
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
||||
|
||||
#if __ARM_NEON
|
||||
|
||||
@ -1395,7 +1397,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 32) {
|
||||
// Load 32x4-bit integers into 32x8-bit integers
|
||||
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
||||
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
||||
|
||||
// Subtract 8 from the integers
|
||||
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
||||
@ -1513,7 +1515,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||
|
||||
for (int l = 0; l < QK4_1; l += 32) {
|
||||
// Load 32x4-bit integers into 32x8-bit integers
|
||||
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
||||
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
||||
|
||||
// Convert to 16-bit int
|
||||
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
||||
@ -2356,7 +2358,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
/* Compute combined scale for the block */
|
||||
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
||||
|
||||
__m256i bx = bytesFromNibbles(x[i].qs);
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
|
||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||
const __m256i off = _mm256_set1_epi8( 8 );
|
||||
@ -2402,7 +2404,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
__m128i i32[2];
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
||||
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
|
||||
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
|
||||
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
||||
|
||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||
@ -2567,7 +2569,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
|
||||
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
||||
|
||||
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
||||
const __m256i bx = bytesFromNibbles( x[i].qs );
|
||||
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
||||
|
||||
// Get absolute values of x vectors
|
||||
@ -2721,6 +2723,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
|
||||
}
|
||||
|
||||
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
#elif defined(__AVX2__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; i++) {
|
||||
/* Compute combined scale for the block */
|
||||
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
||||
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
||||
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
||||
|
||||
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
||||
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
||||
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
||||
|
||||
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
||||
const __m256i off = _mm256_set1_epi8(8);
|
||||
bx = _mm256_sub_epi8(bx, off);
|
||||
|
||||
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
// Get absolute values of x vectors
|
||||
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
||||
// Sign the values of the y vectors
|
||||
const __m256i sy = _mm256_sign_epi8(by, bx);
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||
|
||||
const __m256i ones = _mm256_set1_epi16(1);
|
||||
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
||||
|
||||
/* Convert to vectore of 8 int32_t to 8 floats */
|
||||
__m256 q = _mm256_cvtepi32_ps(xy_q);
|
||||
|
||||
/* Multiply q with scale and accumulate */
|
||||
acc = _mm256_fmadd_ps(d, q, acc);
|
||||
}
|
||||
|
||||
// Return horizontal sum of the acc vector
|
||||
__m128 res = _mm256_extractf128_ps(acc, 1);
|
||||
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
||||
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
||||
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
||||
|
||||
sumf = _mm_cvtss_f32(res);
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user