mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 23:00:46 +01:00
wip : simdify ms, vs
This commit is contained in:
parent
f2efa6cd98
commit
806382a3a6
@ -2052,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
|
||||
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D);
|
||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
|
||||
//threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
|
||||
|
||||
half4 ps4[Q][L4];
|
||||
|
||||
@ -2079,10 +2079,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
{
|
||||
half S[Q] = { 0.0h };
|
||||
half M[Q] = { -INFINITY };
|
||||
|
||||
// assume K and V are same shape
|
||||
const int64_t ne22 = ne12;
|
||||
const int64_t ne23 = ne13;
|
||||
@ -2107,6 +2107,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
const int64_t iv3 = iq3 / rv3;
|
||||
|
||||
simdgroup_half8x8 mq[D8];
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], pq + i*8, T);
|
||||
}
|
||||
@ -2151,46 +2152,85 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
simdgroup_store(mqk, ss, T, 0, false);
|
||||
}
|
||||
|
||||
if (tiisg < Q) {
|
||||
const int64_t j = tiisg;
|
||||
//if (tiisg < Q) {
|
||||
// const int64_t j = tiisg;
|
||||
|
||||
for (int p = 0; p < C; ++p) {
|
||||
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||
// for (int p = 0; p < C; ++p) {
|
||||
// const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||
|
||||
const half m = M;
|
||||
// const half m = M;
|
||||
|
||||
M = max(M, s);
|
||||
// M = max(M, s);
|
||||
|
||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M);
|
||||
// const half ms = m == -INFINITY ? 0.0h : exp(m - M);
|
||||
// const half vs = s == -INFINITY ? 0.0h : exp(s - M);
|
||||
|
||||
S = S*ms + vs;
|
||||
// S = S*ms + vs;
|
||||
|
||||
ss[j*T + 0 + p] = ms;
|
||||
// ss[j*T + 0 + p] = ms;
|
||||
// ss[j*T + C + p] = vs;
|
||||
// }
|
||||
//}
|
||||
|
||||
// not sure why this barrier is needed
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = tiisg % C;
|
||||
|
||||
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||
|
||||
half m = M[j];
|
||||
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms + 0.25h*simd_sum(vs);
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] *= ms;
|
||||
}
|
||||
|
||||
if (tiisg < C) {
|
||||
ss[j*T + C + p] = vs;
|
||||
}
|
||||
}
|
||||
|
||||
for (int p = 0; p < C; ++p) {
|
||||
for (int64_t p = 0; p < C; ++p) {
|
||||
const int64_t ic = iic + p;
|
||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half ms = ss[j*T + 0 + p];
|
||||
const half vs = ss[j*T + C + p];
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
|
||||
ps4[j][i] += pv4[N4*i + tiisg]*vs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//for (int p = 0; p < C; ++p) {
|
||||
// const int64_t ic = iic + p;
|
||||
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
// for (int64_t j = 0; j < Q; ++j) {
|
||||
// const half ms = ss[j*T + 0 + p];
|
||||
// const half vs = ss[j*T + C + p];
|
||||
|
||||
// for (int64_t i = 0; i < L4; ++i) {
|
||||
// ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
if (tiisg < Q) {
|
||||
const int64_t j = tiisg;
|
||||
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
ss[j*T + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2198,6 +2238,9 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
// reduce the warps
|
||||
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
if (sgitg == sg) {
|
||||
// store heads to shared memory - reuse pq4
|
||||
@ -2241,7 +2284,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
S = ss[j*T + 0];
|
||||
const half S = ss[j*T + 0];
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j][i] = ps4[j][i]/S;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user