cuda : fix __hisinf() result check

This commit is contained in:
Georgi Gerganov 2024-02-02 15:12:28 +02:00
parent 12eaa22628
commit b68a112204
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16(
smax = warp_reduce_max(__hmax(smax, s));
M[j] = warp_reduce_max(__hmax(M[j], s));
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
S[j] = S[j]*ms + warp_reduce_sum(vs);
@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16(
smax = warp_reduce_max(smax);
M[j] = warp_reduce_max(M[j]);
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16(
for (int64_t p = lane_id; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
ls += vs;
@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
}
// skip -INF blocks
if (__hisinf(smax)) {
if (__hisinf(smax) == -1) {
continue;
}
@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16(
M = __hmax(M0, M1);
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
S = S0*ms0 + S1*ms1;