mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
This commit is contained in:
parent
b150abe83e
commit
7c34655b36
70
ggml-cuda.cu
70
ggml-cuda.cu
@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
half16x16_acc lo[Q16][D16];
|
half16x16_acc lo[Q16][D16];
|
||||||
|
|
||||||
// load heads from Q to shared memory
|
// load heads from Q to shared memory
|
||||||
for (int64_t j = warp_id; j < Q; j += num_warps) {
|
for (int j = warp_id; j < Q; j += num_warps) {
|
||||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||||
|
|
||||||
for (int64_t i = lane_id; i < D2; i += NW) {
|
for (int i = lane_id; i < D2; i += NW) {
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < ne01) {
|
||||||
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
|
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
|
||||||
} else {
|
} else {
|
||||||
@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
nvcuda::wmma::fill_fragment(zr, 0.0);
|
nvcuda::wmma::fill_fragment(zr, 0.0);
|
||||||
|
|
||||||
// zero out lo
|
// zero out lo
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero out shared memory SH
|
// zero out shared memory SH
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
for (int64_t i = lane_id; i < SH; i += NW) {
|
for (int i = lane_id; i < SH; i += NW) {
|
||||||
ss[j*T + i] = 0.0;
|
ss[j*T + i] = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
half16x16_a mq[Q16][D16];
|
half16x16_a mq[Q16][D16];
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (int cc = 0; cc < C/16; ++cc) {
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
half16x16_acc mqk[Q16];
|
half16x16_acc mqk[Q16];
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
half16x16_bT mk; // transposed key
|
half16x16_bT mk; // transposed key
|
||||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mqk = mqk*scale + mask
|
// mqk = mqk*scale + mask
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
half16x16_a mqka;
|
half16x16_a mqka;
|
||||||
half16x16_acc mm;
|
half16x16_acc mm;
|
||||||
|
|
||||||
@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
if (C == 32) {
|
if (C == 32) {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
const int64_t p = lane_id;
|
const int p = lane_id;
|
||||||
|
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
ss[j*T + p] = vs;
|
ss[j*T + p] = vs;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
const half m = M[j];
|
const half m = M[j];
|
||||||
|
|
||||||
for (int64_t p = lane_id; p < C; p += NW) {
|
for (int p = lane_id; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
smax = __hmax(smax, s);
|
smax = __hmax(smax, s);
|
||||||
@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
// local sum
|
// local sum
|
||||||
half ls = 0.0f;
|
half ls = 0.0f;
|
||||||
|
|
||||||
for (int64_t p = lane_id; p < C; p += NW) {
|
for (int p = lane_id; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
||||||
@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
half16x16_a mm;
|
half16x16_a mm;
|
||||||
half16x16_b lob;
|
half16x16_b lob;
|
||||||
|
|
||||||
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
||||||
|
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
// convert accumulator to matrix_b
|
// convert accumulator to matrix_b
|
||||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
||||||
@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
half16x16_b mk[D16];
|
half16x16_b mk[D16];
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
||||||
}
|
}
|
||||||
|
|
||||||
half16x16_a mv[Q16];
|
half16x16_a mv[Q16];
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
if (lane_id == 0) {
|
if (lane_id == 0) {
|
||||||
ss[j*T + 0] = S[j];
|
ss[j*T + 0] = S[j];
|
||||||
ss[j*T + 1] = M[j];
|
ss[j*T + 1] = M[j];
|
||||||
@ -6708,7 +6708,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reduce the warps sequentially
|
// reduce the warps sequentially
|
||||||
for (int64_t sg = 1; sg < num_warps; ++sg) {
|
for (int sg = 1; sg < num_warps; ++sg) {
|
||||||
half S = __float2half(0.0f);
|
half S = __float2half(0.0f);
|
||||||
half M = __float2half(-INFINITY);
|
half M = __float2half(-INFINITY);
|
||||||
|
|
||||||
@ -6716,8 +6716,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// each simdgroup stores its output to shared memory, reusing sq
|
// each simdgroup stores its output to shared memory, reusing sq
|
||||||
if (warp_id == sg) {
|
if (warp_id == sg) {
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// the first simdgroup accumulates the results from the other simdgroups
|
// the first simdgroup accumulates the results from the other simdgroups
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
const half S0 = ss[j*T + 0];
|
const half S0 = ss[j*T + 0];
|
||||||
const half S1 = ss[j*T + sg*SH + 0];
|
const half S1 = ss[j*T + sg*SH + 0];
|
||||||
|
|
||||||
@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
half16x16_a ms0;
|
half16x16_a ms0;
|
||||||
half16x16_a ms1;
|
half16x16_a ms1;
|
||||||
half16x16_b t;
|
half16x16_b t;
|
||||||
@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
|
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
|
||||||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
||||||
|
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||||
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
|
||||||
|
|
||||||
@ -6776,8 +6776,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// store result to shared memory (reuse sq)
|
// store result to shared memory (reuse sq)
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int64_t j = 0; j < Q16; ++j) {
|
for (int j = 0; j < Q16; ++j) {
|
||||||
for (int64_t i = 0; i < D16; ++i) {
|
for (int i = 0; i < D16; ++i) {
|
||||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6785,10 +6785,10 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
// final rescale with 1/S and store to global memory
|
// final rescale with 1/S and store to global memory
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||||
const half S = ss[j*T + 0];
|
const half S = ss[j*T + 0];
|
||||||
|
|
||||||
for (int64_t i = lane_id; i < D; i += NW) {
|
for (int i = lane_id; i < D; i += NW) {
|
||||||
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
|
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user