cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)
This commit is contained in:
Georgi Gerganov 2024-02-03 13:39:46 +02:00
parent b150abe83e
commit 7c34655b36
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16(
half16x16_acc lo[Q16][D16];
// load heads from Q to shared memory
for (int64_t j = warp_id; j < Q; j += num_warps) {
for (int j = warp_id; j < Q; j += num_warps) {
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int64_t i = lane_id; i < D2; i += NW) {
for (int i = lane_id; i < D2; i += NW) {
if (iq1 + j < ne01) {
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
} else {
@@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(zr, 0.0);
// zero out lo
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
}
}
// zero out shared memory SH
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = lane_id; i < SH; i += NW) {
for (int j = 0; j < Q; ++j) {
for (int i = lane_id; i < SH; i += NW) {
ss[j*T + i] = 0.0;
}
}
@@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16(
// load the queries from shared memory into local memory
half16x16_a mq[Q16][D16];
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
}
}
@@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
// Q*K^T
{
for (int cc = 0; cc < C/16; ++cc) {
half16x16_acc mqk[Q16];
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D16; ++i) {
for (int i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
half16x16_a mqka;
half16x16_acc mm;
@@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16(
// online softmax
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = lane_id;
for (int j = 0; j < Q; ++j) {
const int p = lane_id;
const half m = M[j];
const half s = ss[j*T + p];
@@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16(
ss[j*T + p] = vs;
}
} else {
for (int64_t j = 0; j < Q; ++j) {
for (int j = 0; j < Q; ++j) {
const half m = M[j];
for (int64_t p = lane_id; p < C; p += NW) {
for (int p = lane_id; p < C; p += NW) {
const half s = ss[j*T + p];
smax = __hmax(smax, s);
@@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16(
// local sum
half ls = 0.0f;
for (int64_t p = lane_id; p < C; p += NW) {
for (int p = lane_id; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16(
}
// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
half16x16_a mm;
half16x16_b lob;
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
for (int64_t i = 0; i < D16; ++i) {
for (int i = 0; i < D16; ++i) {
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
@@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16(
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mk[D16];
for (int64_t i = 0; i < D16; ++i) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a mv[Q16];
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
}
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
}
}
@@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16(
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
for (int j = 0; j < Q; ++j) {
if (lane_id == 0) {
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
@@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16(
}
// reduce the warps sequentially
for (int64_t sg = 1; sg < num_warps; ++sg) {
for (int sg = 1; sg < num_warps; ++sg) {
half S = __float2half(0.0f);
half M = __float2half(-INFINITY);
@@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16(
// each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
@@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16(
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int64_t j = 0; j < Q; ++j) {
for (int j = 0; j < Q; ++j) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];
@@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int64_t j = 0; j < Q16; ++j) {
for (int j = 0; j < Q16; ++j) {
half16x16_a ms0;
half16x16_a ms1;
half16x16_b t;
@@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int64_t i = 0; i < D16; ++i) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
@@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16(
// store result to shared memory (reuse sq)
if (warp_id == 0) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
@@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16(
// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int64_t i = lane_id; i < D; i += NW) {
for (int i = lane_id; i < D; i += NW) {
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}