diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 4bf03a49f..c00f8606a 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -1,6 +1,7 @@ #pragma once #include "common.cuh" +#include "convert.cuh" #include "vecdotq.cuh" #include @@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)( template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ > MIN_CC_DP4A +#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; GGML_UNUSED(Q_v); @@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ > MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ > MIN_CC_DP4A +#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; GGML_UNUSED(Q_v); @@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ > MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ > MIN_CC_DP4A +#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; GGML_UNUSED(Q_v); @@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ > MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ > MIN_CC_DP4A +#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; GGML_UNUSED(Q_v); @@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ > MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ > MIN_CC_DP4A +#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; GGML_UNUSED(Q_v); @@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ > MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template @@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) { } template -void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { +void launch_fattn( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, + const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V +) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); + char * K_data = (char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + char * V_data = (char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); @@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern fattn_kernel<<>>( (const char *) Q->data, - (const char *) K->data, - (const char *) V->data, + K_data, + V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, @@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - V->nb[1], V->nb[2], V->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 3d64a9eba..cb11d7212 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index 61fce0a7e..15e22f495 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml-cuda/fattn-vec-f16.cuh b/ggml-cuda/fattn-vec-f16.cuh index ea4fcc972..9e1aa2c6b 100644 --- a/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml-cuda/fattn-vec-f16.cuh @@ -290,7 +290,9 @@ template ; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + constexpr bool need_f16_K = D != 128; + constexpr bool need_f16_V = D != 128 && D != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } template diff --git a/ggml-cuda/fattn-vec-f32.cuh b/ggml-cuda/fattn-vec-f32.cuh index 3009f0f43..ce23a4ebd 100644 --- a/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml-cuda/fattn-vec-f32.cuh @@ -271,7 +271,9 @@ template ; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + constexpr bool need_f16_K = D != 128; + constexpr bool need_f16_V = D != 128 && D != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } template diff --git a/ggml-cuda/fattn-wmma-f16.cuh b/ggml-cuda/fattn-wmma-f16.cuh index 65ed31852..59cd30d78 100644 --- a/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml-cuda/fattn-wmma-f16.cuh @@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index b35ab67a8..38d30b210 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[2]; - const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type); - // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (cc >= CC_OFFSET_AMD || quantized_KV) { + if (cc >= CC_OFFSET_AMD) { if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else {