From 672244a88b678e1e2c17c5fb82ae6c75638bf954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 21 May 2024 19:38:25 +0200 Subject: [PATCH] CUDA: quantized KV support for FA vec --- CMakeLists.txt | 4 + Makefile | 5 +- README.md | 1 + ggml-cuda/fattn-common.cuh | 495 +++++++++++++++++++++++++++++++++++- ggml-cuda/fattn-tile-f16.cu | 3 + ggml-cuda/fattn-tile-f32.cu | 3 + ggml-cuda/fattn-vec-f16.cu | 248 ++++++++++++------ ggml-cuda/fattn-vec-f32.cu | 186 ++++++++++---- ggml-cuda/fattn.cu | 11 +- ggml-cuda/mmq.cu | 4 +- ggml-cuda/vecdotq.cuh | 8 +- 11 files changed, 826 insertions(+), 142 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239..743619b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "llama: max. batch size for using peer access") option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF) option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF) +option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile all quants for FlashAttention" OFF) option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) @@ -427,6 +428,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() + if (LLAMA_CUDA_FA_ALL_QUANTS) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + endif() if (LLAMA_STATIC) if (WIN32) diff --git a/Makefile b/Makefile index 5caf31cdf..94cfdab4e 100644 --- a/Makefile +++ b/Makefile @@ -493,7 +493,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY endif # LLAMA_CUDA_NO_PEER_COPY ifdef LLAMA_CUDA_CCBIN MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) -endif +endif # LLAMA_CUDA_CCBIN +ifdef LLAMA_CUDA_FA_ALL_QUANTS + MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS +endif # LLAMA_CUDA_FA_ALL_QUANTS ifdef JETSON_EOL_MODULE_DETECT define NVCC_COMPILE diff --git a/README.md b/README.md index 15519c97f..4134f1b8d 100644 --- a/README.md +++ b/README.md @@ -481,6 +481,7 @@ Building the program with BLAS support may lead to some performance improvements | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | + | LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | - #### hipBLAS diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 1dd519bde..89bbc9749 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -1,4 +1,5 @@ #include "common.cuh" +#include "vecdotq.cuh" #include @@ -34,11 +35,463 @@ typedef void (* fattn_kernel_t)( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, const int ne3); +typedef half (*vec_dot_KQ_f16_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +typedef float (*vec_dot_KQ_f32_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + half sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid4d8 + m4s8scaled); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid5d8 + m5s8scaled); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + const int v = get_int_from_int8(K_q8_0[ib].qs, iqs); + + T Q_d; + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + } else { + const float2 * Q_ds = (const float2 *) Q_ds_v; + Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + } + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const half2 * K_h2 = (const half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_h2 = (const half2 *) Q_v; + + half2 sum2 = make_half2(0.0f, 0.0f); + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + return __low2half(sum2) + __high2half(sum2); + } +#endif // FP16_AVAILABLE + + const float2 * Q_f2 = (const float2 *) Q_v; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + } + + return sum; +} + +template +static __device__ __forceinline__ void quantize_q8_1_to_shared( + const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { + + float vals[sizeof(int)] = {0.0f}; +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + vals[l] = scale * x[4*threadIdx.x + l]; + } + + float amax = fabsf(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < sizeof(int); ++l) { + amax = fmaxf(amax, fabsf(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + q8[l] = roundf(vals[l] / d); + } + } + + yq32[threadIdx.x] = q32; + if (threadIdx.x % QI8_1 == 0) { + if (std::is_same::value) { + ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); + } else { + ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); + } + } +} + +typedef half (*dequantize_1_f16_t)(const void *, const int64_t); +typedef float (*dequantize_1_f32_t)(const void *, const int64_t); + +template +static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i / QK4_0; + const int iqs = i % (QK4_0/2); + const int shift = (i % QK4_0) / (QK4_0/2); + + const T d = x[ib].d; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i / QK4_1; + const int iqs = i % (QK4_1/2); + const int shift = (i % QK4_1) / (QK4_1/2); + + const half2 dm = x[ib].dm; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i / QK5_0; + const int idq = i % QK5_0; + const int iqs = i % (QK5_0/2); + const int shift = (i % QK5_0) / (QK5_0/2); + + const T d = x[ib].d; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh) - 16; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i / QK5_1; + const int idq = i % QK5_1; + const int iqs = i % (QK5_1/2); + const int shift = (i % QK5_1) / (QK5_1/2); + + const half2 dm = x[ib].dm; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i / QK8_0; + const int iqs = i % QK8_0; + + const T d = x[ib].d; + const int q = x[ib].qs[iqs]; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { + const half * x = (const half *) vx; + + return x[i]; +} + template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -83,6 +536,45 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } +// Aliases for FATTN_VEC_CASE macro: +static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0; +static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1; +static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0; +static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1; +static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0; +static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16; + +typedef half f16; +typedef float f32; + +#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V) \ + if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == ggml_type_##type_suffix_V) { \ + constexpr int nwarps = (D)/WARP_SIZE; \ + constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16; \ + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \ + (D), cols_per_block, parallel_blocks, \ + vec_dot_fattn_vec_KQ_##type_suffix_K, Q_q8_1, dequantize_1_##type_suffix_V>; \ + launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \ + return; \ + } \ + +static void on_no_fattn_vec_case(const int D) { + if (D == 64) { + fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); + fprintf(stderr, "By default only f16 KV cache is supported.\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); + GGML_ASSERT(false); + } else { + fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); + fprintf(stderr, "Supported combinations:\n"); + fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); + fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); + GGML_ASSERT(false); + } +} + template void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { const ggml_tensor * Q = dst->src[0]; @@ -94,8 +586,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(K->type == GGML_TYPE_F16); - GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -143,6 +633,7 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], + V->nb[1], V->nb[2], V->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index cdb5eaff7..3d64a9eba 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index 5a3de2918..61fce0a7e 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 808e8f362..fa6750364 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -2,7 +2,7 @@ #include "fattn-common.cuh" #include "fattn-vec-f16.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -45,13 +48,11 @@ static __global__ void flash_attn_vec_ext_f16( const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); - const int stride_KV = nb11 / sizeof(half); - const int stride_KV2 = nb11 / sizeof(half2); + const half * maskh = (const half *) mask + ne11*ic0; const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -62,10 +63,6 @@ static __global__ void flash_attn_vec_ext_f16( __builtin_assume(tid < D); __shared__ half KQ[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; @@ -86,17 +83,76 @@ static __global__ void flash_attn_vec_ext_f16( } __syncthreads(); - // Convert Q to half2 and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + } + + #pragma unroll for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } + KQ[j*D + tid] = -HALF_MAX_HALF; } half2 VKQ[ncols] = {{0.0f, 0.0f}}; @@ -123,22 +179,10 @@ static __global__ void flash_attn_vec_ext_f16( break; } - half2 sum2[ncols] = {{0.0f, 0.0f}}; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; - } - } - #pragma unroll for (int j = 0; j < ncols; ++j) { - sum2[j] = warp_reduce_sum(sum2[j]); - half sum = __low2half(sum2[j]) + __high2half(sum2[j]); + half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -189,8 +233,8 @@ static __global__ void flash_attn_vec_ext_f16( } half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; @@ -248,19 +292,22 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens case 64: { constexpr int D = 64; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 256: { constexpr int D = 256; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: @@ -272,57 +319,100 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens template void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_CASE(f16, 64, f16, q4_0) + FATTN_VEC_CASE(f16, 64, f16, q4_1) + FATTN_VEC_CASE(f16, 64, f16, q5_0) + FATTN_VEC_CASE(f16, 64, f16, q5_1) + FATTN_VEC_CASE(f16, 64, f16, q8_0) + FATTN_VEC_CASE(f16, 64, f16, f16) + + FATTN_VEC_CASE(f16, 128, q4_0, q4_0) + FATTN_VEC_CASE(f16, 128, q4_0, q4_1) + FATTN_VEC_CASE(f16, 128, q4_0, q5_0) + FATTN_VEC_CASE(f16, 128, q4_0, q5_1) + FATTN_VEC_CASE(f16, 128, q4_0, q8_0) + FATTN_VEC_CASE(f16, 128, q4_0, f16) + + FATTN_VEC_CASE(f16, 128, q4_1, q4_0) + FATTN_VEC_CASE(f16, 128, q4_1, q4_1) + FATTN_VEC_CASE(f16, 128, q4_1, q5_0) + FATTN_VEC_CASE(f16, 128, q4_1, q5_1) + FATTN_VEC_CASE(f16, 128, q4_1, q8_0) + FATTN_VEC_CASE(f16, 128, q4_1, f16) + + FATTN_VEC_CASE(f16, 128, q5_0, q4_0) + FATTN_VEC_CASE(f16, 128, q5_0, q4_1) + FATTN_VEC_CASE(f16, 128, q5_0, q5_0) + FATTN_VEC_CASE(f16, 128, q5_0, q5_1) + FATTN_VEC_CASE(f16, 128, q5_0, q8_0) + FATTN_VEC_CASE(f16, 128, q5_0, f16) + + FATTN_VEC_CASE(f16, 128, q5_1, q4_0) + FATTN_VEC_CASE(f16, 128, q5_1, q4_1) + FATTN_VEC_CASE(f16, 128, q5_1, q5_0) + FATTN_VEC_CASE(f16, 128, q5_1, q5_1) + FATTN_VEC_CASE(f16, 128, q5_1, q8_0) + FATTN_VEC_CASE(f16, 128, q5_1, f16) + + FATTN_VEC_CASE(f16, 128, q8_0, q4_0) + FATTN_VEC_CASE(f16, 128, q8_0, q4_1) + FATTN_VEC_CASE(f16, 128, q8_0, q5_0) + FATTN_VEC_CASE(f16, 128, q8_0, q5_1) + FATTN_VEC_CASE(f16, 128, q8_0, q8_0) + FATTN_VEC_CASE(f16, 128, q8_0, f16) + + FATTN_VEC_CASE(f16, 128, f16, q4_0) + FATTN_VEC_CASE(f16, 128, f16, q4_1) + FATTN_VEC_CASE(f16, 128, f16, q5_0) + FATTN_VEC_CASE(f16, 128, f16, q5_1) + FATTN_VEC_CASE(f16, 128, f16, q8_0) + FATTN_VEC_CASE(f16, 128, f16, f16) +#else + FATTN_VEC_CASE(f16, 128, q4_0, q4_0) + FATTN_VEC_CASE(f16, 128, q8_0, q8_0) + FATTN_VEC_CASE(f16, 128, f16, f16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; const int32_t precision = KQV->op_params[2]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); - if (Q->ne[1] == 1) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } + // if (Q->ne[1] == 1) { + // constexpr int cols_per_block = 1; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] == 2) { + // constexpr int cols_per_block = 2; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] <= 4) { + // constexpr int cols_per_block = 4; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] <= 8) { + // constexpr int cols_per_block = 8; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index b4652301b..dded24320 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -2,7 +2,7 @@ #include "fattn-common.cuh" #include "fattn-vec-f32.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f32( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -44,13 +47,10 @@ static __global__ void flash_attn_vec_ext_f32( const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; - - const int stride_KV = nb11 / sizeof(half); - const int stride_KV2 = nb11 / sizeof(half2); + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic0; const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -83,17 +83,69 @@ static __global__ void flash_attn_vec_ext_f32( } __syncthreads(); - // Convert Q to half2 and store in registers: - float2 Q_h2[ncols][D/(2*WARP_SIZE)]; + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: + float2 Q_f2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { #pragma unroll - for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; - Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE].x *= scale; - Q_h2[j][i0/WARP_SIZE].y *= scale; + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_f2[j][i0/WARP_SIZE].x *= scale; + Q_f2[j][i0/WARP_SIZE].y *= scale; + } } } @@ -117,28 +169,16 @@ static __global__ void flash_attn_vec_ext_f32( break; } - float sum[ncols] = {0.0f}; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x; - sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y; - } - } - #pragma unroll for (int j = 0; j < ncols; ++j) { - sum[j] = warp_reduce_sum(sum[j]); - sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]); + kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum[j]; + KQ[j*D + i_KQ] = sum; } } } @@ -178,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f32( break; } - const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]); + const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*D + k]; @@ -223,23 +263,65 @@ static __global__ void flash_attn_vec_ext_f32( template void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_CASE(f32, 64, f16, q4_0) + FATTN_VEC_CASE(f32, 64, f16, q4_1) + FATTN_VEC_CASE(f32, 64, f16, q5_0) + FATTN_VEC_CASE(f32, 64, f16, q5_1) + FATTN_VEC_CASE(f32, 64, f16, q8_0) + FATTN_VEC_CASE(f32, 64, f16, f16) + + FATTN_VEC_CASE(f32, 128, q4_0, q4_0) + FATTN_VEC_CASE(f32, 128, q4_0, q4_1) + FATTN_VEC_CASE(f32, 128, q4_0, q5_0) + FATTN_VEC_CASE(f32, 128, q4_0, q5_1) + FATTN_VEC_CASE(f32, 128, q4_0, q8_0) + FATTN_VEC_CASE(f32, 128, q4_0, f16) + + FATTN_VEC_CASE(f32, 128, q4_1, q4_0) + FATTN_VEC_CASE(f32, 128, q4_1, q4_1) + FATTN_VEC_CASE(f32, 128, q4_1, q5_0) + FATTN_VEC_CASE(f32, 128, q4_1, q5_1) + FATTN_VEC_CASE(f32, 128, q4_1, q8_0) + FATTN_VEC_CASE(f32, 128, q4_1, f16) + + FATTN_VEC_CASE(f32, 128, q5_0, q4_0) + FATTN_VEC_CASE(f32, 128, q5_0, q4_1) + FATTN_VEC_CASE(f32, 128, q5_0, q5_0) + FATTN_VEC_CASE(f32, 128, q5_0, q5_1) + FATTN_VEC_CASE(f32, 128, q5_0, q8_0) + FATTN_VEC_CASE(f32, 128, q5_0, f16) + + FATTN_VEC_CASE(f32, 128, q5_1, q4_0) + FATTN_VEC_CASE(f32, 128, q5_1, q4_1) + FATTN_VEC_CASE(f32, 128, q5_1, q5_0) + FATTN_VEC_CASE(f32, 128, q5_1, q5_1) + FATTN_VEC_CASE(f32, 128, q5_1, q8_0) + FATTN_VEC_CASE(f32, 128, q5_1, f16) + + FATTN_VEC_CASE(f32, 128, q8_0, q4_0) + FATTN_VEC_CASE(f32, 128, q8_0, q4_1) + FATTN_VEC_CASE(f32, 128, q8_0, q5_0) + FATTN_VEC_CASE(f32, 128, q8_0, q5_1) + FATTN_VEC_CASE(f32, 128, q8_0, q8_0) + FATTN_VEC_CASE(f32, 128, q8_0, f16) + + FATTN_VEC_CASE(f32, 128, f16, q4_0) + FATTN_VEC_CASE(f32, 128, f16, q4_1) + FATTN_VEC_CASE(f32, 128, f16, q5_0) + FATTN_VEC_CASE(f32, 128, f16, q5_1) + FATTN_VEC_CASE(f32, 128, f16, q8_0) + FATTN_VEC_CASE(f32, 128, f16, f16) +#else + FATTN_VEC_CASE(f32, 128, q4_0, q4_0) + FATTN_VEC_CASE(f32, 128, q8_0, q8_0) + FATTN_VEC_CASE(f32, 128, f16, f16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index af7c95232..81b50eb58 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -45,6 +45,9 @@ static __global__ void flash_attn_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -457,14 +460,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[2]; + const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type); + // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (cc >= CC_OFFSET_AMD) { - if (precision == GGML_PREC_DEFAULT) { + if (cc >= CC_OFFSET_AMD || quantized_KV) { + if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index c0a66d9b6..ebe1dc5c8 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; } - return vec_dot_q8_0_q8_1_impl + return vec_dot_q8_0_q8_1_impl (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } @@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - return vec_dot_q8_0_q8_1_impl + return vec_dot_q8_0_q8_1_impl (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); } diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index 5ebdddcc7..df9752390 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -180,8 +180,8 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp #define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMQ 8 -template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( - const int * v, const int * u, const float & d8_0, const float & d8_1) { +template static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const T & d8_0, const T & d8_1) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; @@ -192,7 +192,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp sumi = __dp4a(v[i], u[i], sumi); } - return d8_0*d8_1 * sumi; + return d8_0*d8_1 * ((T) sumi); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -656,7 +656,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1(