Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-27 22:59:24 +01:00)
cuda : avoid warp_reduce for smax

commit b150abe83e
parent b68a112204
@@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16(
                 M[j] = __hmax(M[j], s);
             }
 
-            smax = warp_reduce_max(smax);
             M[j] = warp_reduce_max(M[j]);
 
             const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
@@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16(
             }
         }
 
+        smax = warp_reduce_max(smax);
+
         // skip -INF blocks
         if (__hisinf(smax) == -1) {
             continue;
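For context: the change drops the per-iteration reduction of smax inside the inner j loop and instead reduces it once, right before the -INF block-skip check, where the warp-wide value is actually needed. Below is a minimal sketch of the shape such a warp_reduce_max helper typically has, assuming the common butterfly XOR-shuffle pattern; the helper actually defined in llama.cpp may differ in details (type overloads, arch guards).

#include <cuda_fp16.h>

// Sketch only: assumed shape of warp_reduce_max, not the verbatim
// llama.cpp helper. Butterfly reduction: after log2(32) = 5 XOR-shuffle
// steps, every lane of the warp holds the maximum over all 32 lanes.
// Note: the half-precision __hmax device intrinsic requires
// compute capability >= 8.0.
static __device__ __forceinline__ half warp_reduce_max(half x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // exchange with the lane whose id differs in one bit; keep the max
        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

Because every lane ends up with the same reduced value, calling it once after the loop is equivalent to (and cheaper than) reducing smax on every iteration, which is what the two hunks above accomplish.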