diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2c050c0c4..0136fbf28 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16(
             smax = warp_reduce_max(__hmax(smax, s));
             M[j] = warp_reduce_max(__hmax(M[j], s));

-            const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
-            const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
+            const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
+            const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);

             S[j] = S[j]*ms + warp_reduce_sum(vs);
@@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16(
             smax = warp_reduce_max(smax);
             M[j] = warp_reduce_max(M[j]);

-            const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+            const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);

             // create a QxQ diagonal matrix for rescaling the output
             if (lane_id == j) {
@@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16(
             for (int64_t p = lane_id; p < C; p += NW) {
                 const half s = ss[j*T + p];

-                const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
+                const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);

                 ls += vs;
@@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
        }

        // skip -INF blocks
-       if (__hisinf(smax)) {
+       if (__hisinf(smax) == -1) {
            continue;
        }
@@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16(
        M = __hmax(M0, M1);

-       const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
-       const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
+       const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
+       const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);

        S = S0*ms0 + S1*ms1;
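The patch replaces truthiness checks on `__hisinf()` with an explicit comparison against -1. The CUDA intrinsic `__hisinf()` returns -1 for negative infinity, +1 for positive infinity, and 0 otherwise, so the old form also zeroed the softmax rescale factors for a +INF value, not just for the -INF sentinel used for masked scores (cf. the `// skip -INF blocks` comment). Below is a minimal standalone sketch, separate from the patch itself (the demo kernel and file name are illustrative), showing the difference between the two checks:

```cuda
// hisinf_demo.cu -- illustrative only, not part of the patch.
// Shows why __hisinf(x) is compared against -1: the intrinsic returns
// -1 for -INF, +1 for +INF, and 0 for finite values and NaN, so a plain
// truthiness check also fires for +INF.
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>

__global__ void hisinf_demo() {
    const float  inputs[3] = { -INFINITY, INFINITY, 1.0f };
    const char * labels[3] = { "-INF", "+INF", "1.0" };

    for (int i = 0; i < 3; ++i) {
        const half h = __float2half(inputs[i]);
        // old check: __hisinf(h) is truthy for BOTH infinities
        // new check: __hisinf(h) == -1 is true ONLY for -INF
        printf("%5s | __hisinf = %2d | old check fires: %d | new check fires: %d\n",
               labels[i], __hisinf(h), __hisinf(h) != 0, __hisinf(h) == -1);
    }
}

int main() {
    hisinf_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```

Compiled for an fp16-capable target (sm_53 or newer), the old check fires for both infinities while the new one fires only for -INF, so only fully masked scores collapse the rescale factors `ms`/`vs` to zero instead of going through `hexp()`.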