wip : simdify ms, vs (compute the exp(m - M) and exp(s - M) factors with simd_max / simd_sum)

This commit is contained in:
Georgi Gerganov 2024-01-25 09:39:22 +02:00
parent f2efa6cd98
commit 806382a3a6
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@@ -2052,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D);
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
//threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
half4 ps4[Q][L4];
@@ -2079,10 +2079,10 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup_barrier(mem_flags::mem_threadgroup);
half S = { 0.0h };
half M = { -INFINITY };
{
half S[Q] = { 0.0h };
half M[Q] = { -INFINITY };
// assume K and V are same shape
const int64_t ne22 = ne12;
const int64_t ne23 = ne13;
@@ -2107,6 +2107,7 @@ kernel void kernel_flash_attn_ext_f16(
const int64_t iv3 = iq3 / rv3;
simdgroup_half8x8 mq[D8];
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mq[i], pq + i*8, T);
}
@@ -2151,46 +2152,85 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_store(mqk, ss, T, 0, false);
}
if (tiisg < Q) {
const int64_t j = tiisg;
//if (tiisg < Q) {
// const int64_t j = tiisg;
for (int p = 0; p < C; ++p) {
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
// for (int p = 0; p < C; ++p) {
// const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
const half m = M;
// const half m = M;
M = max(M, s);
// M = max(M, s);
const half ms = m == -INFINITY ? 0.0h : exp(m - M);
const half vs = s == -INFINITY ? 0.0h : exp(s - M);
// const half ms = m == -INFINITY ? 0.0h : exp(m - M);
// const half vs = s == -INFINITY ? 0.0h : exp(s - M);
S = S*ms + vs;
// S = S*ms + vs;
ss[j*T + 0 + p] = ms;
// ss[j*T + 0 + p] = ms;
// ss[j*T + C + p] = vs;
// }
//}
// not sure why this barrier is needed
simdgroup_barrier(mem_flags::mem_none);
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg % C;
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
half m = M[j];
M[j] = simd_max(max(M[j], s));
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j]*ms + 0.25h*simd_sum(vs);
for (int64_t i = 0; i < L4; ++i) {
ps4[j][i] *= ms;
}
if (tiisg < C) {
ss[j*T + C + p] = vs;
}
}
for (int p = 0; p < C; ++p) {
for (int64_t p = 0; p < C; ++p) {
const int64_t ic = iic + p;
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
for (int64_t j = 0; j < Q; ++j) {
const half ms = ss[j*T + 0 + p];
const half vs = ss[j*T + C + p];
for (int64_t i = 0; i < L4; ++i) {
ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
ps4[j][i] += pv4[N4*i + tiisg]*vs;
}
}
}
//for (int p = 0; p < C; ++p) {
// const int64_t ic = iic + p;
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
// for (int64_t j = 0; j < Q; ++j) {
// const half ms = ss[j*T + 0 + p];
// const half vs = ss[j*T + C + p];
// for (int64_t i = 0; i < L4; ++i) {
// ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
// }
// }
//}
}
if (tiisg < Q) {
const int64_t j = tiisg;
ss[j*T + 0] = S;
ss[j*T + 1] = M;
for (int64_t j = 0; j < Q; ++j) {
if (tiisg == 0) {
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
}
}
@@ -2198,6 +2238,9 @@ kernel void kernel_flash_attn_ext_f16(
// reduce the warps
half S = { 0.0h };
half M = { -INFINITY };
for (int64_t sg = 1; sg < nsg; ++sg) {
if (sgitg == sg) {
// store heads to shared memory - reuse pq4
@@ -2241,7 +2284,7 @@ kernel void kernel_flash_attn_ext_f16(
if (sgitg == 0) {
for (int64_t j = 0; j < Q; ++j) {
S = ss[j*T + 0];
const half S = ss[j*T + 0];
for (int64_t i = 0; i < L4; ++i) {
ps4[j][i] = ps4[j][i]/S;
}