From 1f8a5924823aecaa6ab1d5c2ac70ddde1d6c27d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 Feb 2024 14:01:32 +0200 Subject: [PATCH] cuda : make loops use the same loop values Thanks Johannes again for the tip --- ggml-cuda.cu | 43 +++++++++++++++++++++++++++++++------- tests/test-backend-ops.cpp | 2 +- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 558ffb8ac..a3a6c6455 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int j = warp_id; j < Q; j += num_warps) { + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int i = lane_id; i < D2; i += NW) { + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16( // zero out shared memory SH for (int j = 0; j < Q; ++j) { - for (int i = lane_id; i < SH; i += NW) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + ss[j*T + i] = 0.0; } } @@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { @@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int p = lane_id; p < C; p += NW) { + for (int p0 = 0; p0 < C; p0 += NW) { + const int p = p0 + lane_id; + const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16( for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int i = lane_id; i < D; i += NW) { + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9feb5e1fe..e4076b49c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_leaky_relu()); #if 1 - for (int hs : { 64, 80, 128, }) { + for (int hs : { 128, 64, 80, }) { for (int nh : { 32, }) { for (int kv : { 512, 1024, 2048, 4096, }) { for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {