From 3194a0105868757412ed010b7878727456e6b920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 26 May 2024 20:14:55 +0200 Subject: [PATCH] fix commented-out kernel variants --- ggml-cuda/fattn-vec-f16.cu | 49 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index fa6750364..71581509c 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; const int32_t precision = KQV->op_params[2]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); - // if (Q->ne[1] == 1) { - // constexpr int cols_per_block = 1; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] == 2) { - // constexpr int cols_per_block = 2; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] <= 4) { - // constexpr int cols_per_block = 4; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] <= 8) { - // constexpr int cols_per_block = 8; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1;