mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
This commit is contained in:
parent
b150abe83e
commit
7c34655b36
70
ggml-cuda.cu
70
ggml-cuda.cu
@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
half16x16_acc lo[Q16][D16];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t j = warp_id; j < Q; j += num_warps) {
|
||||
for (int j = warp_id; j < Q; j += num_warps) {
|
||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (int64_t i = lane_id; i < D2; i += NW) {
|
||||
for (int i = lane_id; i < D2; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
|
||||
} else {
|
||||
@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::fill_fragment(zr, 0.0);
|
||||
|
||||
// zero out lo
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int64_t i = lane_id; i < SH; i += NW) {
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
for (int i = lane_id; i < SH; i += NW) {
|
||||
ss[j*T + i] = 0.0;
|
||||
}
|
||||
}
|
||||
@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half16x16_a mq[Q16][D16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||
}
|
||||
}
|
||||
@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
||||
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
||||
// Q*K^T
|
||||
{
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
half16x16_acc mqk[Q16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
||||
}
|
||||
|
||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
half16x16_bT mk; // transposed key
|
||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
half16x16_a mqka;
|
||||
half16x16_acc mm;
|
||||
|
||||
@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// online softmax
|
||||
if (C == 32) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = lane_id;
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const int p = lane_id;
|
||||
|
||||
const half m = M[j];
|
||||
const half s = ss[j*T + p];
|
||||
@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
for (int64_t p = lane_id; p < C; p += NW) {
|
||||
for (int p = lane_id; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
smax = __hmax(smax, s);
|
||||
@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
// local sum
|
||||
half ls = 0.0f;
|
||||
|
||||
for (int64_t p = lane_id; p < C; p += NW) {
|
||||
for (int p = lane_id; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
||||
@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// O = diag(ms)*O
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
half16x16_a mm;
|
||||
half16x16_b lob;
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
// convert accumulator to matrix_b
|
||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
||||
@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
half16x16_b mk[D16];
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
||||
}
|
||||
|
||||
half16x16_a mv[Q16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
||||
}
|
||||
}
|
||||
@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
if (lane_id == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
ss[j*T + 1] = M[j];
|
||||
@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (int64_t sg = 1; sg < num_warps; ++sg) {
|
||||
for (int sg = 1; sg < num_warps; ++sg) {
|
||||
half S = __float2half(0.0f);
|
||||
half M = __float2half(-INFINITY);
|
||||
|
||||
@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (warp_id == sg) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*SH + 0];
|
||||
|
||||
@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
half16x16_a ms0;
|
||||
half16x16_a ms1;
|
||||
half16x16_b t;
|
||||
@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
|
||||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
||||
|
||||
@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// store result to shared memory (reuse sq)
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
for (int i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
|
||||
for (int64_t i = lane_id; i < D; i += NW) {
|
||||
for (int i = lane_id; i < D; i += NW) {
|
||||
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user