Add initial AVX512 support for dot product on Linux (#320)
* Update the Makefile to detect AVX512 support and add the compiler flags when it is available
* Based on the existing AVX2 implementation, compute the dot product on one 32-value block of 4-bit quantized ints at a time (a scalar reference sketch of this block dot product follows the commit metadata below)
* Perform the 8-bit -> 16-bit sign extension and multiply+add on 32 values at a time instead of 16
* Use the built-in AVX512 horizontal reduce-add to get the sum at the end
* Manually unroll the inner dot-product loop to reduce loop-counter overhead
Commit 2e664f1ff4 (parent 8cf9f34edd)
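For orientation, here is a minimal scalar sketch of the computation the new AVX512 path vectorizes: the dot product of two Q4_0 blocks, each stored as one float scale (delta) followed by QK/2 nibble-packed bytes, with nibbles offset from [0..15] to [-8..7]. This sketch is not part of the commit; the function name and standalone block layout are illustrative only, assuming QK == 32.

#include <stdint.h>

#define QK 32

// Scalar reference (illustrative, not from the commit): dot product of two
// Q4_0 blocks. d0 and d1 are the per-block scales; pb0 and pb1 each point at
// QK/2 bytes holding two 4-bit quantized values per byte.
static float q4_0_block_dot_ref(float d0, float d1,
                                const uint8_t * pb0, const uint8_t * pb1) {
    int sum = 0;
    for (int j = 0; j < QK/2; ++j) {
        // Unpack low and high nibbles and shift them from [0..15] to [-8..7]
        const int x0 = (pb0[j] & 0x0F) - 8;
        const int x1 = (pb0[j] >>   4) - 8;
        const int y0 = (pb1[j] & 0x0F) - 8;
        const int y1 = (pb1[j] >>   4) - 8;
        sum += x0*y0 + x1*y1;
    }
    // One combined scale per block pair, just as the SIMD paths compute it
    return d0 * d1 * (float) sum;
}

The existing AVX2 path performs this integer work 16 values at a time after unpacking; the AVX512 path below handles all 32 values of a block in one pass.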
Makefile (32 lines changed)
@@ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
     ifneq (,$(findstring sse3,$(SSE3_M)))
         CFLAGS += -msse3
     endif
+    AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
+    ifneq (,$(findstring avx512f,$(AVX512F_M)))
+        CFLAGS += -mavx512f
+    endif
+    AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
+    ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
+        CFLAGS += -mavx512bw
+    endif
+    AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
+    ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
+        CFLAGS += -mavx512dq
+    endif
+    AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
+    ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
+        CFLAGS += -mavx512vl
+    endif
+    AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
+    ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
+        CFLAGS += -mavx512cd
+    endif
+    AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
+    ifneq (,$(findstring avx512er,$(AVX512ER_M)))
+        CFLAGS += -mavx512er
+    endif
+    AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
+    ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
+        CFLAGS += -mavx512ifma
+    endif
+    AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
+    ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
+        CFLAGS += -mavx512pf
+    endif
 else ifeq ($(UNAME_S),Haiku)
     AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
     ifneq (,$(findstring avx,$(AVX1_M)))
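A note on how the Makefile change connects to the code: with GCC and Clang, each -mavx512* flag predefines a matching feature-test macro (-mavx512f gives __AVX512F__, -mavx512bw gives __AVX512BW__, and so on), and ggml.c dispatches on those macros at compile time. Because the detection greps /proc/cpuinfo on the build machine, it is Linux-specific and reflects the host CPU, matching the "on Linux" scope in the commit title. A minimal sketch of the dispatch shape, illustrative rather than copied from ggml.c:

// Compile-time dispatch sketch (illustrative only). The new AVX-512 branch is
// tested before AVX2 so it takes precedence when both flag sets are present.
#if defined(__AVX512F__)
    // 512-bit vectorized dot product (added by this commit)
#elif defined(__AVX2__)
    // existing 256-bit path
#else
    // scalar fallback
#endif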
ggml.c (78 lines changed)
@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
-#if __AVX2__
+#if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
 }
 #endif
 
-
 // method 5
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+    __m512 acc,
+    const uint8_t * pd0,
+    const uint8_t * pd1,
+    const uint8_t * pb0,
+    const uint8_t * pb1,
+    size_t bs,
+    int i
+) {
+    const float * d0_0 = (const float *) (pd0 + i*bs);
+    const float * d1_0 = (const float *) (pd1 + i*bs);
+
+    const uint8_t * restrict p0 = pb0 + (i+0)*bs;
+    const uint8_t * restrict p1 = pb1 + (i+0)*bs;
+
+    // Compute combined scale for the block
+    float scaleScalar = d0_0[0] * d1_0[0];
+    __m512 scale = _mm512_set1_ps( scaleScalar );
+
+    __m256i bx = bytesFromNibbles( p0 );
+    __m256i by = bytesFromNibbles( p1 );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 32 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise
+    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i64 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( scale, p, acc );
+}
+#endif
+
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #else
 #error "not implemented for QK"
 #endif
+#elif defined(__AVX512F__)
+
+#if QK == 32
+    // Initialize accumulator with zeros
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
+
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
+
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
+
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+    }
+
+    // Remainders
+    for (int i = superblock_count * superblock_size; i < nb; ++i) {
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#else
+#error "not implemented for QK"
+#endif
 #elif defined(__AVX2__)
 #if QK == 32
     const size_t countBlocks = nb;
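To make the core trick from the commit message concrete in isolation: AVX-512 sign-extends 32 int8 values to int16 in a single step (versus 16 at a time with AVX2), _mm512_madd_epi16 then multiplies and pairwise-adds all 32 products, and the built-in _mm512_reduce_add_* intrinsics perform the final horizontal sum. The standalone sketch below is not from the commit; it assumes AVX512F and AVX512BW are available (i.e. built with the flags the Makefile adds).

#include <immintrin.h>
#include <stdint.h>

// Standalone sketch of the 32-wide pattern used by dot_q4_0_oneblock_avx512:
// dot product of two 32-element int8 vectors.
static inline int32_t dot_i8x32_avx512(const int8_t * x, const int8_t * y) {
    const __m256i xb = _mm256_loadu_si256((const __m256i *) x);
    const __m256i yb = _mm256_loadu_si256((const __m256i *) y);

    // Sign-extend 32 bytes to 32 int16 lanes in one step (AVX2 needs two 16-wide steps)
    const __m512i xw = _mm512_cvtepi8_epi16(xb);
    const __m512i yw = _mm512_cvtepi8_epi16(yb);

    // 32 int16 products, summed pairwise into 16 int32 lanes
    const __m512i prod = _mm512_madd_epi16(xw, yw);

    // Built-in horizontal reduction across all 16 lanes
    return _mm512_reduce_add_epi32(prod);
}

The two-accumulator unrolling in ggml_vec_dot_q4_0 applies the same idea at a higher level: alternating acc0 and acc1 creates two independent _mm512_fmadd_ps dependency chains that can overlap in the pipeline, and both accumulators are reduced with _mm512_reduce_add_ps at the end.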