From 133d99c59980139f5bb75922c8b5fca67d7ba9b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 18 May 2024 12:36:25 +0200
Subject: [PATCH] CUDA: deduplicate FlashAttention code (#7352)

---
 ggml-cuda/common.cuh        |  11 ++
 ggml-cuda/fattn-common.cuh  | 115 +++++++++++++++++++
 ggml-cuda/fattn-tile-f16.cu | 135 +++++-----------------
 ggml-cuda/fattn-tile-f32.cu | 134 +++++-----------------
 ggml-cuda/fattn-vec-f16.cu  | 215 ++++++++++--------------------------
 ggml-cuda/fattn-vec-f32.cu  | 170 ++++++----------------------
 ggml-cuda/fattn.cu          | 177 ++++++++---------------------
 ggml-cuda/softmax.cu        |  13 +--
 8 files changed, 316 insertions(+), 654 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 784792ba0..8f6fd71cf 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
 
 //////////////////////
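
The helper above replaces six near-identical ALiBi blocks across the FlashAttention and softmax kernels: for head index h the slope is m0^(h+1) for the first n_head_log2 heads and m1^(2*(h - n_head_log2) + 1) for the rest, where n_head_log2 is the largest power of two not exceeding n_head. A host-side sketch of the same computation (plain C++; the harness and parameter values are illustrative, not part of the patch):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Mirrors the device helper: the slope is 1 when ALiBi is disabled (max_bias <= 0).
    static float alibi_slope_ref(float max_bias, uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return powf(base, exph);
    }

    int main() {
        const float    max_bias    = 8.0f; // a typical ALiBi setting (illustrative)
        const uint32_t n_head      = 12;   // deliberately not a power of two
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); // == 8
        const float    m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        for (uint32_t h = 0; h < n_head; ++h) {
            printf("head %2u: slope %.6f\n", h, alibi_slope_ref(max_bias, h, n_head_log2, m0, m1));
        }
        return 0;
    }
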
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index 33f640691..1dd519bde 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+typedef void (* fattn_kernel_t)(
+    const char * __restrict__ Q,
+    const char * __restrict__ K,
+    const char * __restrict__ V,
+    const char * __restrict__ mask,
+    float      * __restrict__ dst,
+    float2     * __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const int ne00,
+    const int ne01,
+    const int ne02,
+    const int ne03,
+    const int ne10,
+    const int ne11,
+    const int ne12,
+    const int ne13,
+    const int ne31,
+    const int nb31,
+    const int nb01,
+    const int nb02,
+    const int nb03,
+    const int nb11,
+    const int nb12,
+    const int nb13,
+    const int ne0,
+    const int ne1,
+    const int ne2,
+    const int ne3);
+
 template<int D> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
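
launch_fattn is now the single entry point for every FlashAttention variant: it validates the inputs, reads scale/max_bias from op_params, derives the ALiBi slope parameters, launches the supplied kernel, and, when parallel_blocks > 1, reduces the partial results with flash_attn_combine_results. That reduction is a standard streaming-softmax merge: each block contributes a partial numerator plus its own (kqmax, kqsum) pair in dst_meta, and the partials are rescaled to a common maximum before the final division. A minimal host-side model of the merge (plain C++; the sample values are made up, and the kernel's flush-to-zero threshold is omitted):

    #include <cmath>
    #include <cstdio>

    int main() {
        // Two parallel blocks produced partial VKQ numerators under different softmax scales.
        const int   parallel_blocks = 2;
        const float vkq_num[2] = {0.7f, 1.3f}; // partial numerators (one element of the VKQ vector)
        const float kqmax  [2] = {2.0f, 3.5f}; // per-block KQ maxima (dst_meta.x)
        const float kqsum  [2] = {1.9f, 2.4f}; // per-block exp-sums  (dst_meta.y)

        // Rescale every block's contribution to the global maximum, then divide once.
        float kqmax_global = kqmax[0];
        for (int i = 1; i < parallel_blocks; ++i) {
            kqmax_global = fmaxf(kqmax_global, kqmax[i]);
        }
        float numerator = 0.0f, denominator = 0.0f;
        for (int i = 0; i < parallel_blocks; ++i) {
            const float s = expf(kqmax[i] - kqmax_global);
            numerator   += s*vkq_num[i];
            denominator += s*kqsum[i];
        }
        printf("combined VKQ element = %f\n", numerator/denominator);
        return 0;
    }
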
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index d2a1077ed..4a07ac6ad 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-        (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
-        mask ? ((const char *) mask->data) : nullptr,
-        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
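
The tile launchers now dispatch in two stages: the public function picks cols_per_block and parallel_blocks from the number of queries Q->ne[1], and the templated helper picks the head size D. The batching heuristic, mirrored as a standalone function (illustrative only; the thresholds come straight from the code above):

    #include <cstdio>

    // Small batches get 4 parallel blocks per query column to keep the GPU busy;
    // large batches already fill the grid, so the redundant work is avoided.
    static void tile_dispatch(int n_queries, int * cols_per_block, int * parallel_blocks) {
        if (n_queries <= 16) { *cols_per_block = 16; *parallel_blocks = 4; return; }
        if (n_queries <= 32) { *cols_per_block = 32; *parallel_blocks = 4; return; }
        *cols_per_block = 32; *parallel_blocks = 1;
    }

    int main() {
        for (const int n : {1, 16, 24, 512}) {
            int cpb = 0, pb = 0;
            tile_dispatch(n, &cpb, &pb);
            printf("n_queries = %3d -> cols_per_block = %2d, parallel_blocks = %d\n", n, cpb, pb);
        }
        return 0;
    }
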
diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu
index 176895edd..130e7cbdb 100644
--- a/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml-cuda/fattn-tile-f32.cu
@@ -53,17 +53,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -270,124 +260,50 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-        (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
-        mask ? ((const char *) mask->data) : nullptr,
-        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
        return;
    }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu
index a18be5ddc..54e1ac5d1 100644
--- a/ggml-cuda/fattn-vec-f16.cu
+++ b/ggml-cuda/fattn-vec-f16.cu
@@ -53,17 +53,8 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -232,196 +223,104 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && threadIdx.x < ncols) {
-        dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-        (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
-        mask ? ((const char *) mask->data) : nullptr,
-        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
 void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    ggml_tensor * KQV = dst;
+    ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
-    constexpr int cols_per_block = 1;
+    constexpr int cols_per_block  = 1;
     constexpr int parallel_blocks = 4;
     switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 256:
-            launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 256: {
+            constexpr int D      = 256;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
         default:
             GGML_ASSERT(false);
             break;
     }
 }
 
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
+
 void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block = 1;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block = 2;
+        constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block = 4;
+        constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block = 8;
+        constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 8;
+    constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
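
For the vector kernels, nwarps is tied to the head size (D/WARP_SIZE) and launch_fattn derives the grid from that plus the dispatch constants. A worked example of the resulting launch geometry (plain C++; the tensor shape is hypothetical):

    #include <cstdio>

    #define WARP_SIZE 32

    int main() {
        // Hypothetical vec-f16 case: head size 128, 7 queries, 32 heads, batch size 1.
        const int D = 128, ne1 = 7, ne2 = 32, ne3 = 1;
        const int cols_per_block = 8, parallel_blocks = 4; // the Q->ne[1] <= 8 branch above
        const int nwarps = D/WARP_SIZE;                    // == 4

        // Same arithmetic as block_dim/blocks_num in launch_fattn:
        const int blocks_x = parallel_blocks*((ne1 + cols_per_block - 1)/cols_per_block);
        printf("block_dim  = (%d, %d, 1)\n", WARP_SIZE, nwarps);
        printf("blocks_num = (%d, %d, %d)\n", blocks_x, ne2, ne3);
        return 0;
    }
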
diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu
index 91fcdc8c3..5bcabd092 100644
--- a/ggml-cuda/fattn-vec-f32.cu
+++ b/ggml-cuda/fattn-vec-f32.cu
@@ -52,17 +52,7 @@ static __global__ void flash_attn_vec_ext_f32(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -221,161 +211,65 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && threadIdx.x < ncols) {
-        dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-        (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
-        mask ? ((const char *) mask->data) : nullptr,
-        parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block = 1;
+        constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block = 2;
+        constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block = 4;
+        constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block = 8;
+        constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 8;
+    constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index a1918e258..af7c95232 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    half  slopeh = __float2half(1.0f);
-    half2 slope2 = make_half2(1.0f, 1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-        slope2 = make_half2(slopeh, slopeh);
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
 
     frag_b Q_b[D/16][ncols/frag_n];
 
@@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-        (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
-        mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    if ((parallel_blocks) == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-    const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     if (4*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
-    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     ggml_cuda_set_device(ctx.device);
-
-    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
@@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps = 4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
            break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
            break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
@@ -608,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps = 4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
             break;
         // case 256:
-        //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+        //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
         //     break;
         default:
             GGML_ASSERT(false);
@@ -643,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps = 4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
@@ -666,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps = 4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
@@ -694,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps = 4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
             break;
        case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
            break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
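
launch_fattn_f16 keeps the old occupancy heuristic but folds the nsm lookup inside: extra parallel blocks are only worth it while a plain grid would leave SMs idle. The decision, extracted for clarity (illustrative; the SM count is an example value):

    #include <cstdio>

    // Mirrors launch_fattn_f16: prefer 4, then 2, then 1 parallel blocks,
    // depending on whether the unsplit grid would underfill the device.
    static int pick_parallel_blocks(int blocks_num_pb1, int nsm) {
        if (4*blocks_num_pb1 < 2*nsm) return 4;
        if (2*blocks_num_pb1 < 2*nsm) return 2;
        return 1;
    }

    int main() {
        const int nsm = 108; // e.g. an A100-class GPU (illustrative)
        for (const int blocks : {8, 60, 400}) {
            printf("blocks_num_pb1 = %3d -> parallel_blocks = %d\n",
                   blocks, pick_parallel_blocks(blocks, nsm));
        }
        return 0;
    }
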
diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu
index ca85285a3..ce64f2f2c 100644
--- a/ggml-cuda/softmax.cu
+++ b/ggml-cuda/softmax.cu
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"
 
 template <bool vals_smem, int ncols_template, int block_size_template, typename T>
@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
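
With the shared helper, the fused softmax applies ALiBi the same way the FlashAttention kernels do: the per-element logit is scale*x plus slope times the mask value, followed by a numerically stable softmax. A scalar model of one row (plain C++; the inputs are made up, and the kernel's half-precision mask handling and warp reductions are elided):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float scale = 0.125f, slope = 0.0625f;        // illustrative op params
        const float x[4]    = {1.0f, 2.0f, 3.0f, 4.0f};     // one raw KQ row
        const float mask[4] = {0.0f, -1.0f, -2.0f, -3.0f};  // ALiBi-style distances

        float v[4], vmax = -INFINITY, sum = 0.0f;
        for (int i = 0; i < 4; ++i) { v[i] = scale*x[i] + slope*mask[i]; vmax = fmaxf(vmax, v[i]); }
        for (int i = 0; i < 4; ++i) { v[i] = expf(v[i] - vmax); sum += v[i]; }
        for (int i = 0; i < 4; ++i) { printf("p[%d] = %f\n", i, v[i]/sum); }
        return 0;
    }
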