mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-10-30 06:30:15 +01:00
CUDA: faster large batch FA without tensor cores (#7314)
This commit is contained in:
parent
82ca83db3c
commit
0fc1e820a9
395
ggml-cuda/fattn-tile-f16.cu
Normal file
395
ggml-cuda/fattn-tile-f16.cu
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
#include "fattn-common.cuh"
|
||||||
|
#include "fattn-tile-f16.cuh"
|
||||||
|
|
||||||
|
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||||
|
|
||||||
|
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
static __global__ void flash_attn_tile_ext_f16(
|
||||||
|
const char * __restrict__ Q,
|
||||||
|
const char * __restrict__ K,
|
||||||
|
const char * __restrict__ V,
|
||||||
|
const char * __restrict__ mask,
|
||||||
|
float * __restrict__ dst,
|
||||||
|
float2 * __restrict__ dst_meta,
|
||||||
|
const float scale,
|
||||||
|
const float max_bias,
|
||||||
|
const float m0,
|
||||||
|
const float m1,
|
||||||
|
const uint32_t n_head_log2,
|
||||||
|
const int ne00,
|
||||||
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int ne03,
|
||||||
|
const int ne10,
|
||||||
|
const int ne11,
|
||||||
|
const int ne12,
|
||||||
|
const int ne13,
|
||||||
|
const int ne31,
|
||||||
|
const int nb31,
|
||||||
|
const int nb01,
|
||||||
|
const int nb02,
|
||||||
|
const int nb03,
|
||||||
|
const int nb11,
|
||||||
|
const int nb12,
|
||||||
|
const int nb13,
|
||||||
|
const int ne0,
|
||||||
|
const int ne1,
|
||||||
|
const int ne2,
|
||||||
|
const int ne3) {
|
||||||
|
#if FP16_AVAILABLE
|
||||||
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
|
||||||
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||||
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||||
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||||
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||||
|
const half * maskh = (const half *) mask + ne11*ic0;
|
||||||
|
|
||||||
|
const int stride_KV2 = nb11 / sizeof(half2);
|
||||||
|
|
||||||
|
half slopeh = __float2half(1.0f);
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const uint32_t h = blockIdx.y;
|
||||||
|
|
||||||
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
|
slopeh = __float2half(powf(base, exph));
|
||||||
|
}
|
||||||
|
|
||||||
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||||
|
|
||||||
|
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
|
||||||
|
half2 * KQ2 = (half2 *) KQ;
|
||||||
|
|
||||||
|
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
|
||||||
|
|
||||||
|
half kqmax[ncols/nwarps];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
kqmax[j0/nwarps] = -HALF_MAX_HALF;
|
||||||
|
}
|
||||||
|
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||||
|
|
||||||
|
// Convert Q to half2 and store in registers:
|
||||||
|
__shared__ half2 Q_h2[ncols][D/2];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
||||||
|
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
|
||||||
|
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
|
||||||
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||||
|
|
||||||
|
half kqmax_new[ncols/nwarps];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||||
|
kqmax_new[j] = kqmax[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||||
|
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
|
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
|
||||||
|
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
|
||||||
|
half2 Q_k[ncols/nwarps];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
|
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||||
|
|
||||||
|
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
||||||
|
|
||||||
|
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
|
||||||
|
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
|
||||||
|
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
|
||||||
|
const half2 val = h2exp(diff);
|
||||||
|
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
|
||||||
|
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
|
||||||
|
const int k = k0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
|
||||||
|
half2 V_k[(D/2)/WARP_SIZE][2];
|
||||||
|
half2 KQ_k[ncols/nwarps];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
|
||||||
|
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||||
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||||
|
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||||
|
const int i0 = i00 + 2*threadIdx.x;
|
||||||
|
|
||||||
|
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||||
|
if (parallel_blocks == 1) {
|
||||||
|
dst_val /= __half2half2(kqsum_j);
|
||||||
|
}
|
||||||
|
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||||
|
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // FP16_AVAILABLE
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
|
||||||
|
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
||||||
|
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||||
|
) {
|
||||||
|
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||||
|
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||||
|
|
||||||
|
if (parallel_blocks > 1) {
|
||||||
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||||
|
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int nwarps = 8;
|
||||||
|
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||||
|
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||||
|
const int shmem = 0;
|
||||||
|
|
||||||
|
float scale = 1.0f;
|
||||||
|
float max_bias = 0.0f;
|
||||||
|
|
||||||
|
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||||
|
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
|
const uint32_t n_head = Q->ne[2];
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
|
flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
|
||||||
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
(const char *) Q->data,
|
||||||
|
(const char *) K->data,
|
||||||
|
(const char *) V->data,
|
||||||
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
|
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||||
|
scale, max_bias, m0, m1, n_head_log2,
|
||||||
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (parallel_blocks == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const dim3 block_dim_combine(D, 1, 1);
|
||||||
|
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||||
|
const int shmem_combine = 0;
|
||||||
|
|
||||||
|
flash_attn_combine_results<D, parallel_blocks>
|
||||||
|
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||||
|
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
|
ggml_tensor * KQV = dst;
|
||||||
|
|
||||||
|
const int32_t precision = KQV->op_params[2];
|
||||||
|
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||||
|
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||||
|
|
||||||
|
if (Q->ne[1] <= 16) {
|
||||||
|
constexpr int cols_per_block = 16;
|
||||||
|
constexpr int parallel_blocks = 4;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Q->ne[1] <= 32) {
|
||||||
|
constexpr int cols_per_block = 32;
|
||||||
|
constexpr int parallel_blocks = 4;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int cols_per_block = 32;
|
||||||
|
constexpr int parallel_blocks = 1;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
3
ggml-cuda/fattn-tile-f16.cuh
Normal file
3
ggml-cuda/fattn-tile-f16.cuh
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
393
ggml-cuda/fattn-tile-f32.cu
Normal file
393
ggml-cuda/fattn-tile-f32.cu
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
#include "fattn-common.cuh"
|
||||||
|
#include "fattn-tile-f32.cuh"
|
||||||
|
|
||||||
|
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||||
|
|
||||||
|
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
static __global__ void flash_attn_tile_ext_f32(
|
||||||
|
const char * __restrict__ Q,
|
||||||
|
const char * __restrict__ K,
|
||||||
|
const char * __restrict__ V,
|
||||||
|
const char * __restrict__ mask,
|
||||||
|
float * __restrict__ dst,
|
||||||
|
float2 * __restrict__ dst_meta,
|
||||||
|
const float scale,
|
||||||
|
const float max_bias,
|
||||||
|
const float m0,
|
||||||
|
const float m1,
|
||||||
|
const uint32_t n_head_log2,
|
||||||
|
const int ne00,
|
||||||
|
const int ne01,
|
||||||
|
const int ne02,
|
||||||
|
const int ne03,
|
||||||
|
const int ne10,
|
||||||
|
const int ne11,
|
||||||
|
const int ne12,
|
||||||
|
const int ne13,
|
||||||
|
const int ne31,
|
||||||
|
const int nb31,
|
||||||
|
const int nb01,
|
||||||
|
const int nb02,
|
||||||
|
const int nb03,
|
||||||
|
const int nb11,
|
||||||
|
const int nb12,
|
||||||
|
const int nb13,
|
||||||
|
const int ne0,
|
||||||
|
const int ne1,
|
||||||
|
const int ne2,
|
||||||
|
const int ne3) {
|
||||||
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
|
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||||
|
|
||||||
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||||
|
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||||
|
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||||
|
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||||
|
const half * maskh = (const half *) mask + ne11*ic0;
|
||||||
|
|
||||||
|
const int stride_KV2 = nb11 / sizeof(half2);
|
||||||
|
|
||||||
|
float slope = 1.0f;
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
if (max_bias > 0.0f) {
|
||||||
|
const uint32_t h = blockIdx.y;
|
||||||
|
|
||||||
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
|
|
||||||
|
slope = powf(base, exph);
|
||||||
|
}
|
||||||
|
|
||||||
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||||
|
|
||||||
|
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
|
||||||
|
|
||||||
|
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
|
||||||
|
float2 * KV_tmp2 = (float2 *) KV_tmp;
|
||||||
|
|
||||||
|
float kqmax[ncols/nwarps];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
|
||||||
|
}
|
||||||
|
float kqsum[ncols/nwarps] = {0.0f};
|
||||||
|
|
||||||
|
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
|
||||||
|
|
||||||
|
// Convert Q to half2 and store in registers:
|
||||||
|
__shared__ float Q_f[ncols][D];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||||
|
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
|
||||||
|
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
||||||
|
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
|
||||||
|
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
|
||||||
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||||
|
|
||||||
|
float kqmax_new[ncols/nwarps];
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||||
|
kqmax_new[j] = kqmax[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
|
||||||
|
const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
|
||||||
|
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
|
||||||
|
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
|
||||||
|
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
|
||||||
|
float Q_k[ncols/nwarps];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
|
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
|
||||||
|
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||||
|
|
||||||
|
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
|
||||||
|
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
|
||||||
|
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
|
||||||
|
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||||
|
|
||||||
|
float kqsum_add = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
|
||||||
|
const float val = expf(diff);
|
||||||
|
kqsum_add += val;
|
||||||
|
KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
|
||||||
|
}
|
||||||
|
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
|
||||||
|
const int k = k0 + threadIdx.y;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||||
|
KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
|
||||||
|
float2 V_k[(D/2)/WARP_SIZE];
|
||||||
|
float KQ_k[ncols/nwarps];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
|
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
const int j = j0 + threadIdx.y;
|
||||||
|
|
||||||
|
KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
|
||||||
|
VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||||
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||||
|
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||||
|
const int i0 = i00 + 2*threadIdx.x;
|
||||||
|
|
||||||
|
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||||
|
if (parallel_blocks == 1) {
|
||||||
|
dst_val.x /= kqsum_j;
|
||||||
|
dst_val.y /= kqsum_j;
|
||||||
|
}
|
||||||
|
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
|
||||||
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parallel_blocks != 1 && threadIdx.x == 0) {
|
||||||
|
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
|
||||||
|
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
|
||||||
|
ggml_cuda_pool & pool, cudaStream_t main_stream
|
||||||
|
) {
|
||||||
|
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||||
|
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||||
|
|
||||||
|
if (parallel_blocks > 1) {
|
||||||
|
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||||
|
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int nwarps = 8;
|
||||||
|
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||||
|
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||||
|
const int shmem = 0;
|
||||||
|
|
||||||
|
float scale = 1.0f;
|
||||||
|
float max_bias = 0.0f;
|
||||||
|
|
||||||
|
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||||
|
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
|
const uint32_t n_head = Q->ne[2];
|
||||||
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||||
|
|
||||||
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
|
flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
|
||||||
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
(const char *) Q->data,
|
||||||
|
(const char *) K->data,
|
||||||
|
(const char *) V->data,
|
||||||
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
|
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||||
|
scale, max_bias, m0, m1, n_head_log2,
|
||||||
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (parallel_blocks == 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const dim3 block_dim_combine(D, 1, 1);
|
||||||
|
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||||
|
const int shmem_combine = 0;
|
||||||
|
|
||||||
|
flash_attn_combine_results<D, parallel_blocks>
|
||||||
|
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||||
|
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
|
ggml_tensor * KQV = dst;
|
||||||
|
|
||||||
|
const int32_t precision = KQV->op_params[2];
|
||||||
|
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||||
|
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||||
|
|
||||||
|
if (Q->ne[1] <= 16) {
|
||||||
|
constexpr int cols_per_block = 16;
|
||||||
|
constexpr int parallel_blocks = 4;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Q->ne[1] <= 32) {
|
||||||
|
constexpr int cols_per_block = 32;
|
||||||
|
constexpr int parallel_blocks = 4;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int cols_per_block = 32;
|
||||||
|
constexpr int parallel_blocks = 1;
|
||||||
|
switch (Q->ne[0]) {
|
||||||
|
case 64:
|
||||||
|
launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
3
ggml-cuda/fattn-tile-f32.cuh
Normal file
3
ggml-cuda/fattn-tile-f32.cuh
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -57,7 +57,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
const int h = blockIdx.y;
|
const uint32_t h = blockIdx.y;
|
||||||
|
|
||||||
const float base = h < n_head_log2 ? m0 : m1;
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
@ -232,11 +232,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid != 0) {
|
if (parallel_blocks != 1 && threadIdx.x < ncols) {
|
||||||
#pragma unroll
|
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
|
||||||
for (int j = 0; j < ncols; ++j) {
|
|
||||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
@ -56,7 +56,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
const int h = blockIdx.y;
|
const uint32_t h = blockIdx.y;
|
||||||
|
|
||||||
const float base = h < n_head_log2 ? m0 : m1;
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
@ -221,11 +221,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid != 0) {
|
if (parallel_blocks != 1 && threadIdx.x < ncols) {
|
||||||
#pragma unroll
|
dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
|
||||||
for (int j = 0; j < ncols; ++j) {
|
|
||||||
dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "fattn-common.cuh"
|
#include "fattn-common.cuh"
|
||||||
|
#include "fattn-tile-f16.cuh"
|
||||||
|
#include "fattn-tile-f32.cuh"
|
||||||
#include "fattn-vec-f16.cuh"
|
#include "fattn-vec-f16.cuh"
|
||||||
#include "fattn-vec-f32.cuh"
|
#include "fattn-vec-f32.cuh"
|
||||||
#include "fattn.cuh"
|
#include "fattn.cuh"
|
||||||
@ -88,7 +90,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
if (max_bias > 0.0f) {
|
if (max_bias > 0.0f) {
|
||||||
const int h = blockIdx.y;
|
const uint32_t h = blockIdx.y;
|
||||||
|
|
||||||
const float base = h < n_head_log2 ? m0 : m1;
|
const float base = h < n_head_log2 ? m0 : m1;
|
||||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||||
@ -541,13 +543,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
|
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[2];
|
||||||
|
|
||||||
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
|
if (cc >= CC_OFFSET_AMD) {
|
||||||
|
if (precision == GGML_PREC_DEFAULT) {
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (!fast_fp16_available(cc)) {
|
if (!fast_fp16_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
if (Q->ne[1] <= 8) {
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!fp16_mma_available(cc)) {
|
if (!fp16_mma_available(cc)) {
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
if (Q->ne[1] <= 8) {
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user