mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 23:00:46 +01:00
cuda : speed-up reduce part of the kernel
This commit is contained in:
parent
a7b471569b
commit
3b1c4e7673
13
ggml-cuda.cu
13
ggml-cuda.cu
@ -6715,9 +6715,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (int sg = 1; sg < num_warps; ++sg) {
|
||||
half S = __float2half(0.0f);
|
||||
half M = CUDART_MIN_DENORM_FP16;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
@ -6733,28 +6730,26 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (warp_id == 0) {
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
for (int j = lane_id; j < Q; j += NW) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*SH + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*SH + 1];
|
||||
|
||||
M = __hmax(M0, M1);
|
||||
const half M = __hmax(M0, M1);
|
||||
|
||||
const half ms0 = hexp(M0 - M);
|
||||
const half ms1 = hexp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
const half S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (lane_id == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
|
||||
ss[j*T + C + j ] = ms0;
|
||||
ss[j*T + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
@ -10931,6 +10926,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||
const int nqpb = NQPB; // queries per block
|
||||
const int ncpw = NCPW; // cache values per warp (does not work for other values)
|
||||
|
||||
GGML_ASSERT(NQPB <= 32);
|
||||
|
||||
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
|
||||
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
|
||||
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
|
||||
|
Loading…
Reference in New Issue
Block a user