CUDA: quantized KV support for FA vec

This commit is contained in:
Johannes Gäßler 2024-05-21 19:38:25 +02:00
parent 10b1e45876
commit 672244a88b
11 changed files with 826 additions and 142 deletions

View File

@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"llama: max. batch size for using peer access") "llama: max. batch size for using peer access")
option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF) option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF) option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF)
option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile all quants for FlashAttention" OFF)
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
@ -427,6 +428,9 @@ if (LLAMA_CUDA)
if (LLAMA_CUDA_NO_PEER_COPY) if (LLAMA_CUDA_NO_PEER_COPY)
add_compile_definitions(GGML_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
endif() endif()
if (LLAMA_CUDA_FA_ALL_QUANTS)
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
endif()
if (LLAMA_STATIC) if (LLAMA_STATIC)
if (WIN32) if (WIN32)

View File

@ -493,7 +493,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY endif # LLAMA_CUDA_NO_PEER_COPY
ifdef LLAMA_CUDA_CCBIN ifdef LLAMA_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif endif # LLAMA_CUDA_CCBIN
ifdef LLAMA_CUDA_FA_ALL_QUANTS
MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS
endif # LLAMA_CUDA_FA_ALL_QUANTS
ifdef JETSON_EOL_MODULE_DETECT ifdef JETSON_EOL_MODULE_DETECT
define NVCC_COMPILE define NVCC_COMPILE

View File

@ -481,6 +481,7 @@ Building the program with BLAS support may lead to some performance improvements
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
| LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
| LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
- #### hipBLAS - #### hipBLAS

View File

@ -1,4 +1,5 @@
#include "common.cuh" #include "common.cuh"
#include "vecdotq.cuh"
#include <cstdint> #include <cstdint>
@ -34,11 +35,463 @@ typedef void (* fattn_kernel_t)(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,
const int ne3); const int ne3);
typedef half (*vec_dot_KQ_f16_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef float (*vec_dot_KQ_f32_t)(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
GGML_UNUSED(Q_v);
half sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int iqs4 = k_KQ % QI4_0;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
}
}
return sum;
#else
GGML_UNUSED(K_c);
GGML_UNUSED(Q_v);
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int iqs4 = k_KQ % QI4_1;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
sum += (T) (sumid4d8 + m4s8scaled);
}
}
return sum;
#else
GGML_UNUSED(K_c);
GGML_UNUSED(Q_v);
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int iqs4 = k_KQ % QI5_0;
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
}
}
return sum;
#else
GGML_UNUSED(K_c);
GGML_UNUSED(Q_v);
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
template<typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int iqs4 = k_KQ % QI5_1;
const int iqs8 = k_KQ % QI8_1;
const int shift = k_KQ & (QI8_1/2);
int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
v |= (vh << 4) & 0x00000010; // 0 -> 4
v |= (vh << 11) & 0x00001000; // 1 -> 12
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
const int u = Q_q8[k_KQ_0/WARP_SIZE];
const int sumi = __dp4a(v, u, 0);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
sum += (T) (sumid5d8 + m5s8scaled);
}
}
return sum;
#else
GGML_UNUSED(K_c);
GGML_UNUSED(Q_v);
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
template <typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
#if __CUDA_ARCH__ > MIN_CC_DP4A
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_0;
const int iqs = k_KQ % QI8_0;
const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
T Q_d;
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
} else {
const float2 * Q_ds = (const float2 *) Q_ds_v;
Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
}
sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
}
return sum;
#else
GGML_UNUSED(K_c);
GGML_UNUSED(Q_v);
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ > MIN_CC_DP4A
}
template <typename T, int D>
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const half2 * K_h2 = (const half2 *) K_c;
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
const half2 * Q_h2 = (const half2 *) Q_v;
half2 sum2 = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
}
return __low2half(sum2) + __high2half(sum2);
}
#endif // FP16_AVAILABLE
const float2 * Q_f2 = (const float2 *) Q_v;
float sum = 0.0f;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
}
return sum;
}
template <typename Tds>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
float vals[sizeof(int)] = {0.0f};
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
vals[l] = scale * x[4*threadIdx.x + l];
}
float amax = fabsf(vals[0]);
float sum = vals[0];
#pragma unroll
for (int l = 1; l < sizeof(int); ++l) {
amax = fmaxf(amax, fabsf(vals[l]));
sum += vals[l];
}
#pragma unroll
for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
}
const float d = amax / 127;
int q32 = 0;
int8_t * q8 = (int8_t *) &q32;
if (d != 0.0f) {
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
q8[l] = roundf(vals[l] / d);
}
}
yq32[threadIdx.x] = q32;
if (threadIdx.x % QI8_1 == 0) {
if (std::is_same<Tds, half2>::value) {
((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
} else {
((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
}
}
}
typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
template <typename T>
static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const int64_t ib = i / QK4_0;
const int iqs = i % (QK4_0/2);
const int shift = (i % QK4_0) / (QK4_0/2);
const T d = x[ib].d;
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
return d*((T) q);
}
template <typename T>
static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
const block_q4_1 * x = (const block_q4_1 *) vx;
const int64_t ib = i / QK4_1;
const int iqs = i % (QK4_1/2);
const int shift = (i % QK4_1) / (QK4_1/2);
const half2 dm = x[ib].dm;
const int q0 = x[ib].qs[iqs];
const int q = ((q0 >> (4*shift)) & 0x0F);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
#endif // FP16_AVAILABLE
return __low2float(dm)*((float) q) + __high2float(dm);
}
template <typename T>
static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
const block_q5_0 * x = (const block_q5_0 *) vx;
const int64_t ib = i / QK5_0;
const int idq = i % QK5_0;
const int iqs = i % (QK5_0/2);
const int shift = (i % QK5_0) / (QK5_0/2);
const T d = x[ib].d;
const int ql0 = x[ib].qs[iqs];
const int qh0 = get_int_from_uint8(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh) - 16;
return d*((T) q);
}
template <typename T>
static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
const block_q5_1 * x = (const block_q5_1 *) vx;
const int64_t ib = i / QK5_1;
const int idq = i % QK5_1;
const int iqs = i % (QK5_1/2);
const int shift = (i % QK5_1) / (QK5_1/2);
const half2 dm = x[ib].dm;
const int ql0 = x[ib].qs[iqs];
const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
const int ql = ((ql0 >> (4*shift)) & 0x0F);
const int qh = ((qh0 >> idq) << 4) & 0x10;
const int q = (ql | qh);
#if FP16_AVAILABLE
if (std::is_same<T, half>::value) {
return __low2half(dm)*((half) q) + __high2half(dm);
}
#endif // FP16_AVAILABLE
return __low2float(dm)*((float) q) + __high2float(dm);
}
template <typename T>
static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
const block_q8_0 * x = (const block_q8_0 *) vx;
const int64_t ib = i / QK8_0;
const int iqs = i % QK8_0;
const T d = x[ib].d;
const int q = x[ib].qs[iqs];
return d*((T) q);
}
template <typename T>
static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
const half * x = (const half *) vx;
return x[i];
}
template<int D, int parallel_blocks> // D == head size template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1) __launch_bounds__(D, 1)
@ -83,6 +536,45 @@ static __global__ void flash_attn_combine_results(
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
} }
// Aliases for FATTN_VEC_CASE macro:
static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0;
static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1;
static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0;
static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1;
static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0;
static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16;
typedef half f16;
typedef float f32;
#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V) \
if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == ggml_type_##type_suffix_V) { \
constexpr int nwarps = (D)/WARP_SIZE; \
constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16; \
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \
(D), cols_per_block, parallel_blocks, \
vec_dot_fattn_vec_KQ_##type_suffix_K<type_VKQ, (D)>, Q_q8_1, dequantize_1_##type_suffix_V<type_VKQ>>; \
launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \
return; \
} \
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
fprintf(stderr, "By default only f16 KV cache is supported.\n");
fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
GGML_ASSERT(false);
} else {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
GGML_ASSERT(false);
}
}
template <int D, int parallel_blocks> template <int D, int parallel_blocks>
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
@ -94,8 +586,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
ggml_tensor * KQV = dst; ggml_tensor * KQV = dst;
GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
@ -143,6 +633,7 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3], Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3], K->nb[1], K->nb[2], K->nb[3],
V->nb[1], V->nb[2], V->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
); );
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());

View File

@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,

View File

@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,

View File

@ -2,7 +2,7 @@
#include "fattn-common.cuh" #include "fattn-common.cuh"
#include "fattn-vec-f16.cuh" #include "fattn-vec-f16.cuh"
template<int D, int ncols, int parallel_blocks> // D == head size template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f16_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f16_t dequantize_1_v> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1) __launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f16(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,
@ -45,13 +48,11 @@ static __global__ void flash_attn_vec_ext_f16(
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); Q += nb02* blockIdx.y + nb01*ic0;
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); K += nb12*(blockIdx.y / gqa_ratio);
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape V += nb22*(blockIdx.y / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV = nb11 / sizeof(half); const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef); const half slopeh = __float2half(slopef);
@ -62,10 +63,6 @@ static __global__ void flash_attn_vec_ext_f16(
__builtin_assume(tid < D); __builtin_assume(tid < D);
__shared__ half KQ[ncols*D]; __shared__ half KQ[ncols*D];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ[j*D + tid] = -HALF_MAX_HALF;
}
half2 * KQ2 = (half2 *) KQ; half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols]; half kqmax[ncols];
@ -86,18 +83,77 @@ static __global__ void flash_attn_vec_ext_f16(
} }
__syncthreads(); __syncthreads();
// Convert Q to half2 and store in registers: // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
half2 Q_h2[ncols][D/(2*WARP_SIZE)]; half2 Q_h2[ncols][D/(2*WARP_SIZE)];
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
if (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
tmp_q_i32[i] = 0;
}
if (threadIdx.x < D/QK8_1) {
tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
}
continue;
}
const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
}
}
__syncthreads();
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
}
}
__syncthreads();
} else {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x; const int i = i0 + threadIdx.x;
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
} }
} }
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ[j*D + tid] = -HALF_MAX_HALF;
}
half2 VKQ[ncols] = {{0.0f, 0.0f}}; half2 VKQ[ncols] = {{0.0f, 0.0f}};
@ -123,22 +179,10 @@ static __global__ void flash_attn_vec_ext_f16(
break; break;
} }
half2 sum2[ncols] = {{0.0f, 0.0f}};
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
} sum = warp_reduce_sum(sum);
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) { if (ncols == 1) {
@ -189,8 +233,8 @@ static __global__ void flash_attn_vec_ext_f16(
} }
half2 V_k; half2 V_k;
reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
@ -248,19 +292,22 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
case 64: { case 64: {
constexpr int D = 64; constexpr int D = 64;
constexpr int nwarps = D/WARP_SIZE; constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>; fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break; } break;
case 128: { case 128: {
constexpr int D = 128; constexpr int D = 128;
constexpr int nwarps = D/WARP_SIZE; constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>; fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break; } break;
case 256: { case 256: {
constexpr int D = 256; constexpr int D = 256;
constexpr int nwarps = D/WARP_SIZE; constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>; fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<
D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16<half, D>, false, dequantize_1_f16<half>>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break; } break;
default: default:
@ -272,57 +319,100 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
template <int cols_per_block, int parallel_blocks> template <int cols_per_block, int parallel_blocks>
void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) { const ggml_tensor * K = dst->src[1];
case 64: { const ggml_tensor * V = dst->src[2];
constexpr int D = 64;
constexpr int nwarps = D/WARP_SIZE; #ifdef GGML_CUDA_FA_ALL_QUANTS
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>; FATTN_VEC_CASE(f16, 64, f16, q4_0)
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); FATTN_VEC_CASE(f16, 64, f16, q4_1)
} break; FATTN_VEC_CASE(f16, 64, f16, q5_0)
case 128: { FATTN_VEC_CASE(f16, 64, f16, q5_1)
constexpr int D = 128; FATTN_VEC_CASE(f16, 64, f16, q8_0)
constexpr int nwarps = D/WARP_SIZE; FATTN_VEC_CASE(f16, 64, f16, f16)
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); FATTN_VEC_CASE(f16, 128, q4_0, q4_0)
} break; FATTN_VEC_CASE(f16, 128, q4_0, q4_1)
default: { FATTN_VEC_CASE(f16, 128, q4_0, q5_0)
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); FATTN_VEC_CASE(f16, 128, q4_0, q5_1)
} break; FATTN_VEC_CASE(f16, 128, q4_0, q8_0)
} FATTN_VEC_CASE(f16, 128, q4_0, f16)
FATTN_VEC_CASE(f16, 128, q4_1, q4_0)
FATTN_VEC_CASE(f16, 128, q4_1, q4_1)
FATTN_VEC_CASE(f16, 128, q4_1, q5_0)
FATTN_VEC_CASE(f16, 128, q4_1, q5_1)
FATTN_VEC_CASE(f16, 128, q4_1, q8_0)
FATTN_VEC_CASE(f16, 128, q4_1, f16)
FATTN_VEC_CASE(f16, 128, q5_0, q4_0)
FATTN_VEC_CASE(f16, 128, q5_0, q4_1)
FATTN_VEC_CASE(f16, 128, q5_0, q5_0)
FATTN_VEC_CASE(f16, 128, q5_0, q5_1)
FATTN_VEC_CASE(f16, 128, q5_0, q8_0)
FATTN_VEC_CASE(f16, 128, q5_0, f16)
FATTN_VEC_CASE(f16, 128, q5_1, q4_0)
FATTN_VEC_CASE(f16, 128, q5_1, q4_1)
FATTN_VEC_CASE(f16, 128, q5_1, q5_0)
FATTN_VEC_CASE(f16, 128, q5_1, q5_1)
FATTN_VEC_CASE(f16, 128, q5_1, q8_0)
FATTN_VEC_CASE(f16, 128, q5_1, f16)
FATTN_VEC_CASE(f16, 128, q8_0, q4_0)
FATTN_VEC_CASE(f16, 128, q8_0, q4_1)
FATTN_VEC_CASE(f16, 128, q8_0, q5_0)
FATTN_VEC_CASE(f16, 128, q8_0, q5_1)
FATTN_VEC_CASE(f16, 128, q8_0, q8_0)
FATTN_VEC_CASE(f16, 128, q8_0, f16)
FATTN_VEC_CASE(f16, 128, f16, q4_0)
FATTN_VEC_CASE(f16, 128, f16, q4_1)
FATTN_VEC_CASE(f16, 128, f16, q5_0)
FATTN_VEC_CASE(f16, 128, f16, q5_1)
FATTN_VEC_CASE(f16, 128, f16, q8_0)
FATTN_VEC_CASE(f16, 128, f16, f16)
#else
FATTN_VEC_CASE(f16, 128, q4_0, q4_0)
FATTN_VEC_CASE(f16, 128, q8_0, q8_0)
FATTN_VEC_CASE(f16, 128, f16, f16)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
} }
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst; const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
if (Q->ne[1] == 1) { // if (Q->ne[1] == 1) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); // constexpr int cols_per_block = 1;
return; // constexpr int parallel_blocks = 4;
} // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return;
// }
if (Q->ne[1] == 2) { // if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2; // constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4; // constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return; // return;
} // }
if (Q->ne[1] <= 4) { // if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4; // constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4; // constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return; // return;
} // }
if (Q->ne[1] <= 8) { // if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8; // constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4; // constexpr int parallel_blocks = 4;
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
return; // return;
} // }
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;

View File

@ -2,7 +2,7 @@
#include "fattn-common.cuh" #include "fattn-common.cuh"
#include "fattn-vec-f32.cuh" #include "fattn-vec-f32.cuh"
template<int D, int ncols, int parallel_blocks> // D == head size template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f32_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f32_t dequantize_1_v> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1) __launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f32(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,
@ -44,14 +47,11 @@ static __global__ void flash_attn_vec_ext_f32(
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); Q += nb02* blockIdx.y + nb01*ic0;
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); K += nb12*(blockIdx.y / gqa_ratio);
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0; const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -83,17 +83,69 @@ static __global__ void flash_attn_vec_ext_f32(
} }
__syncthreads(); __syncthreads();
// Convert Q to half2 and store in registers: // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
float2 Q_h2[ncols][D/(2*WARP_SIZE)]; float2 Q_f2[ncols][D/(2*WARP_SIZE)];
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
if (Q_q8_1) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
// Reuse KQ as temporary storage for converting Q to q8_1:
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
// Set memory to zero if out of bounds:
if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
tmp_q_i32[i] = 0;
}
if (threadIdx.x < D/QK8_1) {
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
}
continue;
}
const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
}
}
__syncthreads();
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
int * tmp_q_i32 = (int *) &KQ[j*D];
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
}
}
__syncthreads();
} else {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x; const int i = i0 + threadIdx.x;
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_h2[j][i0/WARP_SIZE].x *= scale; Q_f2[j][i0/WARP_SIZE].x *= scale;
Q_h2[j][i0/WARP_SIZE].y *= scale; Q_f2[j][i0/WARP_SIZE].y *= scale;
}
} }
} }
@ -117,28 +169,16 @@ static __global__ void flash_attn_vec_ext_f32(
break; break;
} }
float sum[ncols] = {0.0f};
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x; float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y; sum = warp_reduce_sum(sum);
} sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
}
#pragma unroll kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
for (int j = 0; j < ncols; ++j) {
sum[j] = warp_reduce_sum(sum[j]);
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
KQ[j*D + i_KQ] = sum[j]; KQ[j*D + i_KQ] = sum;
} }
} }
} }
@ -178,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f32(
break; break;
} }
const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]); const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
VKQ[j] += V_ki*KQ[j*D + k]; VKQ[j] += V_ki*KQ[j*D + k];
@ -223,23 +263,65 @@ static __global__ void flash_attn_vec_ext_f32(
template <int cols_per_block, int parallel_blocks> template <int cols_per_block, int parallel_blocks>
void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) { const ggml_tensor * K = dst->src[1];
case 64: { const ggml_tensor * V = dst->src[2];
constexpr int D = 64;
constexpr int nwarps = D/WARP_SIZE; #ifdef GGML_CUDA_FA_ALL_QUANTS
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>; FATTN_VEC_CASE(f32, 64, f16, q4_0)
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); FATTN_VEC_CASE(f32, 64, f16, q4_1)
} break; FATTN_VEC_CASE(f32, 64, f16, q5_0)
case 128: { FATTN_VEC_CASE(f32, 64, f16, q5_1)
constexpr int D = 128; FATTN_VEC_CASE(f32, 64, f16, q8_0)
constexpr int nwarps = D/WARP_SIZE; FATTN_VEC_CASE(f32, 64, f16, f16)
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
} break; FATTN_VEC_CASE(f32, 128, q4_0, q4_1)
default: { FATTN_VEC_CASE(f32, 128, q4_0, q5_0)
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); FATTN_VEC_CASE(f32, 128, q4_0, q5_1)
} break; FATTN_VEC_CASE(f32, 128, q4_0, q8_0)
} FATTN_VEC_CASE(f32, 128, q4_0, f16)
FATTN_VEC_CASE(f32, 128, q4_1, q4_0)
FATTN_VEC_CASE(f32, 128, q4_1, q4_1)
FATTN_VEC_CASE(f32, 128, q4_1, q5_0)
FATTN_VEC_CASE(f32, 128, q4_1, q5_1)
FATTN_VEC_CASE(f32, 128, q4_1, q8_0)
FATTN_VEC_CASE(f32, 128, q4_1, f16)
FATTN_VEC_CASE(f32, 128, q5_0, q4_0)
FATTN_VEC_CASE(f32, 128, q5_0, q4_1)
FATTN_VEC_CASE(f32, 128, q5_0, q5_0)
FATTN_VEC_CASE(f32, 128, q5_0, q5_1)
FATTN_VEC_CASE(f32, 128, q5_0, q8_0)
FATTN_VEC_CASE(f32, 128, q5_0, f16)
FATTN_VEC_CASE(f32, 128, q5_1, q4_0)
FATTN_VEC_CASE(f32, 128, q5_1, q4_1)
FATTN_VEC_CASE(f32, 128, q5_1, q5_0)
FATTN_VEC_CASE(f32, 128, q5_1, q5_1)
FATTN_VEC_CASE(f32, 128, q5_1, q8_0)
FATTN_VEC_CASE(f32, 128, q5_1, f16)
FATTN_VEC_CASE(f32, 128, q8_0, q4_0)
FATTN_VEC_CASE(f32, 128, q8_0, q4_1)
FATTN_VEC_CASE(f32, 128, q8_0, q5_0)
FATTN_VEC_CASE(f32, 128, q8_0, q5_1)
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
FATTN_VEC_CASE(f32, 128, q8_0, f16)
FATTN_VEC_CASE(f32, 128, f16, q4_0)
FATTN_VEC_CASE(f32, 128, f16, q4_1)
FATTN_VEC_CASE(f32, 128, f16, q5_0)
FATTN_VEC_CASE(f32, 128, f16, q5_1)
FATTN_VEC_CASE(f32, 128, f16, q8_0)
FATTN_VEC_CASE(f32, 128, f16, f16)
#else
FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
FATTN_VEC_CASE(f32, 128, f16, f16)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
} }
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -45,6 +45,9 @@ static __global__ void flash_attn_ext_f16(
const int nb11, const int nb11,
const int nb12, const int nb12,
const int nb13, const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2, const int ne2,
@ -457,14 +460,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst; const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
ggml_cuda_set_device(ctx.device); ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
// On AMD the tile kernels perform poorly, use the vec kernel instead: // On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) { if (cc >= CC_OFFSET_AMD || quantized_KV) {
if (precision == GGML_PREC_DEFAULT) { if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
} else { } else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);

View File

@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
} }
return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ> return vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
(&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
} }
@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const float * x_dmf = (const float *) x_dm; const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds; const float * y_df = (const float *) y_ds;
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ> return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
} }

View File

@ -180,8 +180,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
#define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8 #define VDR_Q8_0_Q8_1_MMQ 8
template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
const int * v, const int * u, const float & d8_0, const float & d8_1) { const int * v, const int * u, const T & d8_0, const T & d8_1) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
int sumi = 0; int sumi = 0;
@ -192,7 +192,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
sumi = __dp4a(v[i], u[i], sumi); sumi = __dp4a(v[i], u[i], sumi);
} }
return d8_0*d8_1 * sumi; return d8_0*d8_1 * ((T) sumi);
#else #else
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@ -656,7 +656,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
} }
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds)); return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
} }
static __device__ __forceinline__ float vec_dot_q2_K_q8_1( static __device__ __forceinline__ float vec_dot_q2_K_q8_1(