mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
cuda : fix __hisinf() result check
This commit is contained in:
parent
12eaa22628
commit
b68a112204
14
ggml-cuda.cu
14
ggml-cuda.cu
@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
smax = warp_reduce_max(__hmax(smax, s));
|
smax = warp_reduce_max(__hmax(smax, s));
|
||||||
M[j] = warp_reduce_max(__hmax(M[j], s));
|
M[j] = warp_reduce_max(__hmax(M[j], s));
|
||||||
|
|
||||||
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
|
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
|
||||||
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
|
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
||||||
|
|
||||||
@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
smax = warp_reduce_max(smax);
|
smax = warp_reduce_max(smax);
|
||||||
M[j] = warp_reduce_max(M[j]);
|
M[j] = warp_reduce_max(M[j]);
|
||||||
|
|
||||||
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
|
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (lane_id == j) {
|
if (lane_id == j) {
|
||||||
@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
for (int64_t p = lane_id; p < C; p += NW) {
|
for (int64_t p = lane_id; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
|
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
||||||
|
|
||||||
ls += vs;
|
ls += vs;
|
||||||
|
|
||||||
@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// skip -INF blocks
|
// skip -INF blocks
|
||||||
if (__hisinf(smax)) {
|
if (__hisinf(smax) == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
M = __hmax(M0, M1);
|
M = __hmax(M0, M1);
|
||||||
|
|
||||||
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
|
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
|
||||||
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
|
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
|
||||||
|
|
||||||
S = S0*ms0 + S1*ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user