diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0136fbf28..c3f24242b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16( M[j] = __hmax(M[j], s); } - smax = warp_reduce_max(smax); M[j] = warp_reduce_max(M[j]); const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); @@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16( } } + smax = warp_reduce_max(smax); + // skip -INF blocks if (__hisinf(smax) == -1) { continue;