mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 22:59:24 +01:00
cuda : simplify softmax
This commit is contained in:
parent
e04ff39181
commit
cfd9732b2e
23
ggml-cuda.cu
23
ggml-cuda.cu
@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
half S[Q];
|
||||
half S = __float2half(0.0f);
|
||||
half M[Q];
|
||||
|
||||
for (int i = 0; i < Q; ++i) {
|
||||
S[i] = __float2half(0.0f);
|
||||
M[i] = CUDART_MIN_DENORM_FP16;
|
||||
}
|
||||
|
||||
@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
M[j] = warp_reduce_max(M[j]);
|
||||
|
||||
const half ms = hexp(m - M[j]);
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (lane_id == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
|
||||
// local sum
|
||||
half2 ls = make_half2(0.0f, 0.0f);
|
||||
half2 M2 = make_half2(M[j], M[j]);
|
||||
@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
ls = warp_reduce_sum(ls);
|
||||
|
||||
S[j] = S[j]*ms + ls.x + ls.y;
|
||||
const half ms = hexp(m - M[j]);
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (lane_id == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
|
||||
S = S*ms + ls.x + ls.y;
|
||||
}
|
||||
}
|
||||
|
||||
smax = warp_reduce_max(smax);
|
||||
@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
if (lane_id == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
if (lane_id == j) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M[j];
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user