mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
cuda : simplify softmax
This commit is contained in:
parent
e04ff39181
commit
cfd9732b2e
23
ggml-cuda.cu
23
ggml-cuda.cu
@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
{
|
{
|
||||||
half S[Q];
|
half S = __float2half(0.0f);
|
||||||
half M[Q];
|
half M[Q];
|
||||||
|
|
||||||
for (int i = 0; i < Q; ++i) {
|
for (int i = 0; i < Q; ++i) {
|
||||||
S[i] = __float2half(0.0f);
|
|
||||||
M[i] = CUDART_MIN_DENORM_FP16;
|
M[i] = CUDART_MIN_DENORM_FP16;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
M[j] = warp_reduce_max(M[j]);
|
M[j] = warp_reduce_max(M[j]);
|
||||||
|
|
||||||
const half ms = hexp(m - M[j]);
|
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
|
||||||
if (lane_id == j) {
|
|
||||||
ss[j*T + C + j] = ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
// local sum
|
// local sum
|
||||||
half2 ls = make_half2(0.0f, 0.0f);
|
half2 ls = make_half2(0.0f, 0.0f);
|
||||||
half2 M2 = make_half2(M[j], M[j]);
|
half2 M2 = make_half2(M[j], M[j]);
|
||||||
@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
ls = warp_reduce_sum(ls);
|
ls = warp_reduce_sum(ls);
|
||||||
|
|
||||||
S[j] = S[j]*ms + ls.x + ls.y;
|
const half ms = hexp(m - M[j]);
|
||||||
|
|
||||||
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
|
if (lane_id == j) {
|
||||||
|
ss[j*T + C + j] = ms;
|
||||||
|
|
||||||
|
S = S*ms + ls.x + ls.y;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
smax = warp_reduce_max(smax);
|
smax = warp_reduce_max(smax);
|
||||||
@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||||
for (int j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
if (lane_id == 0) {
|
if (lane_id == j) {
|
||||||
ss[j*T + 0] = S[j];
|
ss[j*T + 0] = S;
|
||||||
ss[j*T + 1] = M[j];
|
ss[j*T + 1] = M[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user