cuda : make loops use the same loop values

Thanks Johannes again for the tip
This commit is contained in:
Georgi Gerganov 2024-02-03 14:01:32 +02:00
parent 7c34655b36
commit 1f8a592482
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 37 additions and 8 deletions

View File

@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16(
half16x16_acc lo[Q16][D16];
// load heads from Q to shared memory
for (int j = warp_id; j < Q; j += num_warps) {
for (int j0 = 0; j0 < Q; j0 += num_warps) {
const int j = j0 + warp_id;
if (j >= Q) {
break;
}
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int i = lane_id; i < D2; i += NW) {
for (int i0 = 0; i0 < D2; i0 += NW) {
const int i = i0 + lane_id;
if (i >= D2) {
break;
}
if (iq1 + j < ne01) {
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
} else {
@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16(
// zero out shared memory SH
for (int j = 0; j < Q; ++j) {
for (int i = lane_id; i < SH; i += NW) {
for (int i0 = 0; i0 < SH; i0 += NW) {
const int i = i0 + lane_id;
if (i >= SH) {
break;
}
ss[j*T + i] = 0.0;
}
}
@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
const int ic = ic0 + warp_id*C;
if (ic >= ne11) {
break;
}
// Q*K^T
{
for (int cc = 0; cc < C/16; ++cc) {
@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16(
for (int j = 0; j < Q; ++j) {
const half m = M[j];
for (int p = lane_id; p < C; p += NW) {
for (int p0 = 0; p0 < C; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
smax = __hmax(smax, s);
@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16(
// local sum
half ls = 0.0f;
for (int p = lane_id; p < C; p += NW) {
for (int p0 = 0; p0 < C; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16(
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int i = lane_id; i < D; i += NW) {
for (int i0 = 0; i0 < D; i0 += NW) {
const int i = i0 + lane_id;
if (i >= D) {
break;
}
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}

View File

@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_leaky_relu());
#if 1
for (int hs : { 64, 80, 128, }) {
for (int hs : { 128, 64, 80, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, 2048, 4096, }) {
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {