From e9e661bd59364e5d4fce035834b6cadcadf8c2ef Mon Sep 17 00:00:00 2001 From: mahorozte <41834471+mahorozte@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:11:43 +0800 Subject: [PATCH] CUDA: remove unnecessary warp reduce in FA (ggml/1032) * kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit * same problem in vec32 --------- Co-authored-by: ZhaoXiaoYu --- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 1 - ggml/src/ggml-cuda/fattn-vec-f32.cuh | 1 - 2 files changed, 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 5ec3b91ae..34a2992c7 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 3d93f4a8a..a28fc8b7f 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float kqmax_new_j = kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; }