mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-29 21:34:51 +01:00
cuda : make loops use the same loop values
Thanks Johannes again for the tip
This commit is contained in:
parent
7c34655b36
commit
1f8a592482
43
ggml-cuda.cu
43
ggml-cuda.cu
@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16(
|
||||
half16x16_acc lo[Q16][D16];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int j = warp_id; j < Q; j += num_warps) {
|
||||
for (int j0 = 0; j0 < Q; j0 += num_warps) {
|
||||
const int j = j0 + warp_id;
|
||||
if (j >= Q) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (int i = lane_id; i < D2; i += NW) {
|
||||
for (int i0 = 0; i0 < D2; i0 += NW) {
|
||||
const int i = i0 + lane_id;
|
||||
if (i >= D2) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (iq1 + j < ne01) {
|
||||
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
|
||||
} else {
|
||||
@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// zero out shared memory SH
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
for (int i = lane_id; i < SH; i += NW) {
|
||||
for (int i0 = 0; i0 < SH; i0 += NW) {
|
||||
const int i = i0 + lane_id;
|
||||
if (i >= SH) {
|
||||
break;
|
||||
}
|
||||
|
||||
ss[j*T + i] = 0.0;
|
||||
}
|
||||
}
|
||||
@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
|
||||
const int ic = ic0 + warp_id*C;
|
||||
if (ic >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
for (int p = lane_id; p < C; p += NW) {
|
||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||
const int p = p0 + lane_id;
|
||||
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
smax = __hmax(smax, s);
|
||||
@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
// local sum
|
||||
half ls = 0.0f;
|
||||
|
||||
for (int p = lane_id; p < C; p += NW) {
|
||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||
const int p = p0 + lane_id;
|
||||
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
||||
@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
|
||||
for (int i = lane_id; i < D; i += NW) {
|
||||
for (int i0 = 0; i0 < D; i0 += NW) {
|
||||
const int i = i0 + lane_id;
|
||||
if (i >= D) {
|
||||
break;
|
||||
}
|
||||
|
||||
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
|
||||
}
|
||||
}
|
||||
|
@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
#if 1
|
||||
for (int hs : { 64, 80, 128, }) {
|
||||
for (int hs : { 128, 64, 80, }) {
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, 2048, 4096, }) {
|
||||
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
|
||||
|
Loading…
Reference in New Issue
Block a user