From 0827b2c1da299805288abbd556d869318f2b121e Mon Sep 17 00:00:00 2001 From: Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> Date: Tue, 31 Dec 2024 19:53:33 +0530 Subject: [PATCH] ggml : fixes for AVXVNNI instruction set with MSVC and Clang (#11027) * Fixes for clang AVX VNNI * enable AVX VNNI and alder lake build for MSVC * Apply suggestions from code review --------- Co-authored-by: slaren --- ggml/src/CMakeLists.txt | 4 ++-- ggml/src/ggml-cpu/CMakeLists.txt | 3 +-- ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 5 ++++- ggml/src/ggml-cpu/ggml-cpu-quants.c | 6 +++++- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 4 +++- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a5f7f7b5b..84101c32c 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -290,9 +290,9 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA) ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512) ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) if (NOT MSVC) - # MSVC doesn't support AVX-VNNI or AMX - ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) + # MSVC doesn't support AMX ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) endif() else () diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f0aecac1b..6b3641c42 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -215,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_DEFINITIONS GGML_SSE42) endif() if (GGML_AVX_VNNI) - # MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2 - #list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI) + list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI) endif() else () if (GGML_NATIVE) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp index 2d79b8b61..622c63f1f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { } static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); return _mm256_dpbusd_epi32(zero, ax, sy); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_avx_epi32(zero, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c index 634c5fa11..8e1472266 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { } static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 00f7f1170..8fce576c3 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX { inline __m256 updot(__m256i u, __m256i s) { __m256i res; -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#elif defined(__AVXVNNI__) + res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif