CUDA: deduplicate FlashAttention code (#7352)

This commit is contained in:
Johannes Gäßler 2024-05-18 12:36:25 +02:00 committed by GitHub
parent cb42c29427
commit 133d99c599
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 316 additions and 654 deletions

View File

@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return powf(base, exph);
}
////////////////////// //////////////////////

View File

@ -1,7 +1,44 @@
#include "common.cuh"
#include <cstdint>
#define FATTN_KQ_STRIDE 256 #define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
typedef void (* fattn_kernel_t)(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int ne0,
const int ne1,
const int ne2,
const int ne3);
template<int D, int parallel_blocks> // D == head size template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1) __launch_bounds__(D, 1)
@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
} }
template <int D, int parallel_blocks>
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}

View File

@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
half slopeh = __float2half(1.0f); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slopeh = __float2half(powf(base, exph));
}
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16( template <int cols_per_block, int parallel_blocks>
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_pool & pool, cudaStream_t main_stream const ggml_tensor * Q = dst->src[0];
) { switch (Q->ne[0]) {
ggml_cuda_pool_alloc<float> dst_tmp(pool); case 64: {
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); constexpr int D = 64;
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int nwarps = 8; constexpr int nwarps = 8;
const dim3 block_dim(WARP_SIZE, nwarps, 1); fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
const int shmem = 0; } break;
case 128: {
float scale = 1.0f; constexpr int D = 128;
float max_bias = 0.0f; constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); } break;
default: {
const uint32_t n_head = Q->ne[2]; GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); } break;
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if (parallel_blocks == 1) {
return;
} }
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
} }
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
if (Q->ne[1] <= 16) { if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16; constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 32) { if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32; constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
constexpr int cols_per_block = 32; constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;
switch (Q->ne[0]) { launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
} }

View File

@ -53,17 +53,7 @@ static __global__ void flash_attn_tile_ext_f32(
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
float slope = 1.0f; const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = powf(base, exph);
}
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@ -270,124 +260,50 @@ static __global__ void flash_attn_tile_ext_f32(
} }
} }
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32( template <int cols_per_block, int parallel_blocks>
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_pool & pool, cudaStream_t main_stream const ggml_tensor * Q = dst->src[0];
) { switch (Q->ne[0]) {
ggml_cuda_pool_alloc<float> dst_tmp(pool); case 64: {
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); constexpr int D = 64;
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int nwarps = 8; constexpr int nwarps = 8;
const dim3 block_dim(WARP_SIZE, nwarps, 1); fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
const int shmem = 0; } break;
case 128: {
float scale = 1.0f; constexpr int D = 128;
float max_bias = 0.0f; constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); } break;
default: {
const uint32_t n_head = Q->ne[2]; GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); } break;
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if (parallel_blocks == 1) {
return;
} }
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
} }
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
if (Q->ne[1] <= 16) { if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16; constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 32) { if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32; constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
constexpr int cols_per_block = 32; constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;
switch (Q->ne[0]) { launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
} }

View File

@ -53,17 +53,8 @@ static __global__ void flash_attn_vec_ext_f16(
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
half slopeh = __float2half(1.0f); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slopeh = __float2half(powf(base, exph));
}
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE; constexpr int nwarps = D / WARP_SIZE;
@ -232,82 +223,17 @@ static __global__ void flash_attn_vec_ext_f16(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
} }
if (parallel_blocks != 1 && threadIdx.x < ncols) { if (parallel_blocks != 1 && tid < ncols) {
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]); dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
} }
#else #else
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // FP16_AVAILABLE #endif // FP16_AVAILABLE
} }
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
ggml_cuda_pool & pool, cudaStream_t main_stream
) {
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if (parallel_blocks == 1) {
return;
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}
void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst; ggml_tensor * KQV = dst;
ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
@ -315,113 +241,86 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
constexpr int cols_per_block = 1; constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64: {
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); constexpr int D = 64;
break; constexpr int nwarps = D/WARP_SIZE;
case 128: fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
break; } break;
case 256: case 128: {
launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); constexpr int D = 128;
break; constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 256: {
constexpr int D = 256;
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
break; break;
} }
} }
template <int cols_per_block, int parallel_blocks>
void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
if (Q->ne[1] == 1) { if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1; ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
constexpr int parallel_blocks = 4;
switch (Q->ne[0]) {
case 64:
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] == 2) { if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2; constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 4) { if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4; constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 8) { if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;
switch (Q->ne[0]) { launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
} }

View File

@ -52,17 +52,7 @@ static __global__ void flash_attn_vec_ext_f32(
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2); const int stride_KV2 = nb11 / sizeof(half2);
float slope = 1.0f; const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = powf(base, exph);
}
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE; constexpr int nwarps = D / WARP_SIZE;
@ -221,161 +211,65 @@ static __global__ void flash_attn_vec_ext_f32(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
} }
if (parallel_blocks != 1 && threadIdx.x < ncols) { if (parallel_blocks != 1 && tid < ncols) {
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]); dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
} }
} }
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32( template <int cols_per_block, int parallel_blocks>
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_pool & pool, cudaStream_t main_stream const ggml_tensor * Q = dst->src[0];
) { switch (Q->ne[0]) {
ggml_cuda_pool_alloc<float> dst_tmp(pool); case 64: {
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); constexpr int D = 64;
constexpr int nwarps = D/WARP_SIZE;
if (parallel_blocks > 1) { fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); } break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
} }
constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if (parallel_blocks == 1) {
return;
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
} }
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
if (Q->ne[1] == 1) { if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1; constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] == 2) { if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2; constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 4) { if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4; constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
if (Q->ne[1] <= 8) { if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
switch (Q->ne[0]) { launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
return; return;
} }
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;
switch (Q->ne[0]) { launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
case 64:
launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
case 128:
launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
break;
default:
GGML_ASSERT(false);
break;
}
} }

View File

@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q = nb01 / sizeof(float); const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half); const int stride_KV = nb11 / sizeof(half);
half slopeh = __float2half(1.0f); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
half2 slope2 = make_half2(1.0f, 1.0f); const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slopeh = __float2half(powf(base, exph));
slope2 = make_half2(slopeh, slopeh);
}
frag_b Q_b[D/16][ncols/frag_n]; frag_b Q_b[D/16][ncols/frag_n];
@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl( template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_pool & pool, cudaStream_t main_stream const ggml_tensor * Q = dst->src[0];
) {
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
if (parallel_blocks > 1) { constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data,
(const char *) K->data,
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}
template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
) {
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
if (4*blocks_num_pb1 < 2*nsm) { if (4*blocks_num_pb1 < 2*nsm) {
launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream); constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
return; return;
} }
if (2*blocks_num_pb1 < 2*nsm) { if (2*blocks_num_pb1 < 2*nsm) {
launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream); constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
return; return;
} }
launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream); constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} }
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0]; const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_tensor * KQV = dst;
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
ggml_cuda_set_device(ctx.device); ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
// On AMD the tile kernels perform poorly, use the vec kernel instead: // On AMD the tile kernels perform poorly, use the vec kernel instead:
@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4; constexpr int nwarps = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 80: case 80:
launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 96: case 96:
launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 112: case 112:
launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 128: case 128:
launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 256: case 256:
launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -608,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4; constexpr int nwarps = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 80: case 80:
launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 96: case 96:
launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 112: case 112:
launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
break; break;
case 128: case 128:
launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
break; break;
// case 256: // case 256:
// launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); // launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
// break; // break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -643,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4; constexpr int nwarps = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 96: case 96:
launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 128: case 128:
launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 256: case 256:
launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -666,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4; constexpr int nwarps = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 80: case 80:
launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 96: case 96:
launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 112: case 112:
launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 128: case 128:
launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 256: case 256:
launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -694,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4; constexpr int nwarps = 4;
switch (Q->ne[0]) { switch (Q->ne[0]) {
case 64: case 64:
launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 80: case 80:
launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 96: case 96:
launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 112: case 112:
launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 128: case 128:
launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break; break;
case 256: case 256:
launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break; break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);

View File

@ -1,3 +1,4 @@
#include "common.cuh"
#include "softmax.cuh" #include "softmax.cuh"
template <typename T> template <typename T>
@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE;
float slope = 1.0f; const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
// ALiBi
if (max_bias > 0.0f) {
const int h = rowx/nrows_y; // head index
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = powf(base, exph);
}
extern __shared__ float data_soft_max_f32[]; extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication