diff --git a/CMakeLists.txt b/CMakeLists.txt index 616698c7f..92c9f09eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,7 @@ option(LLAMA_AVX2 "llama: enable AVX2" option(LLAMA_AVX512 "llama: enable AVX512" OFF) option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF) option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512 if (NOT MSVC) @@ -1060,6 +1061,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW add_compile_definitions($<$:__AVX512VNNI__>) add_compile_definitions($<$:__AVX512VNNI__>) endif() + if (LLAMA_AVX512_BF16) + add_compile_definitions($<$:__AVX512BF16__>) + add_compile_definitions($<$:__AVX512BF16__>) + endif() elseif (LLAMA_AVX2) list(APPEND ARCH_FLAGS /arch:AVX2) elseif (LLAMA_AVX) @@ -1091,6 +1096,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW if (LLAMA_AVX512_VNNI) list(APPEND ARCH_FLAGS -mavx512vnni) endif() + if (LLAMA_AVX512_BF16) + list(APPEND ARCH_FLAGS -mavx512bf16) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") diff --git a/ggml-impl.h b/ggml-impl.h index 59684fa81..5ff014fe3 100644 --- a/ggml-impl.h +++ b/ggml-impl.h @@ -17,6 +17,18 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +#if defined(_WIN32) + +#define m512bh(p) p +#define m512i(p) p + +#else + +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) + +#endif + /** * Converts brain16 to float32. * diff --git a/ggml.c b/ggml.c index 3a104c486..53da231ee 100644 --- a/ggml.c +++ b/ggml.c @@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { int i = 0; #if defined(__AVX512BF16__) for (; i + 32 <= n; i += 32) { - _mm512_storeu_ps( - (__m512 *)(y + i), - (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), - _mm512_loadu_ps(x + i))); + _mm512_storeu_si512( + (__m512i *)(y + i), + m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i)))); } #endif for (; i < n; i++) { @@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t __m512 c1 = _mm512_setzero_ps(); __m512 c2 = _mm512_setzero_ps(); for (; i + 64 <= n; i += 64) { - c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i))); - c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); + c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), + m512bh(_mm512_loadu_si512((y + i)))); + c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), + m512bh(_mm512_loadu_si512((y + i + 32)))); } sumf += (ggml_float)_mm512_reduce_add_ps(c1); sumf += (ggml_float)_mm512_reduce_add_ps(c2); @@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) { #endif } +int ggml_cpu_has_avx512_bf16(void) { +#if defined(__AVX512BF16__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_fma(void) { #if defined(__FMA__) return 1; diff --git a/ggml.h b/ggml.h index 8c13f4ba8..774757101 100644 --- a/ggml.h +++ b/ggml.h @@ -2390,6 +2390,7 @@ extern "C" { GGML_API int ggml_cpu_has_avx512 (void); GGML_API int ggml_cpu_has_avx512_vbmi(void); GGML_API int ggml_cpu_has_avx512_vnni(void); + GGML_API int ggml_cpu_has_avx512_bf16(void); GGML_API int ggml_cpu_has_fma (void); GGML_API int ggml_cpu_has_neon (void); GGML_API int ggml_cpu_has_arm_fma (void); diff --git a/llama.cpp b/llama.cpp index 102bc2020..ca3e9fcc0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -18074,6 +18074,7 @@ const char * llama_print_system_info(void) { s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; + s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | "; s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";