Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-15 23:00:46 +01:00)
wip : simdify ms, vs
This commit is contained in:
Parent: f2efa6cd98
Commit: 806382a3a6
@ -2052,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
|
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
|
||||||
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
|
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
|
||||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D);
|
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(2*C) + 1*D);
|
||||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
|
//threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(2*C) + 1*D);
|
||||||
|
|
||||||
half4 ps4[Q][L4];
|
half4 ps4[Q][L4];
|
||||||
|
|
||||||
@ -2079,10 +2079,10 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
half S = { 0.0h };
|
|
||||||
half M = { -INFINITY };
|
|
||||||
|
|
||||||
{
|
{
|
||||||
|
half S[Q] = { 0.0h };
|
||||||
|
half M[Q] = { -INFINITY };
|
||||||
|
|
||||||
// assume K and V are same shape
|
// assume K and V are same shape
|
||||||
const int64_t ne22 = ne12;
|
const int64_t ne22 = ne12;
|
||||||
const int64_t ne23 = ne13;
|
const int64_t ne23 = ne13;
|
||||||
@ -2107,6 +2107,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const int64_t iv3 = iq3 / rv3;
|
const int64_t iv3 = iq3 / rv3;
|
||||||
|
|
||||||
simdgroup_half8x8 mq[D8];
|
simdgroup_half8x8 mq[D8];
|
||||||
|
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
simdgroup_load(mq[i], pq + i*8, T);
|
simdgroup_load(mq[i], pq + i*8, T);
|
||||||
}
|
}
|
||||||
@ -2151,46 +2152,85 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
simdgroup_store(mqk, ss, T, 0, false);
|
simdgroup_store(mqk, ss, T, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiisg < Q) {
|
//if (tiisg < Q) {
|
||||||
const int64_t j = tiisg;
|
// const int64_t j = tiisg;
|
||||||
|
|
||||||
for (int p = 0; p < C; ++p) {
|
// for (int p = 0; p < C; ++p) {
|
||||||
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
// const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||||
|
|
||||||
const half m = M;
|
// const half m = M;
|
||||||
|
|
||||||
M = max(M, s);
|
// M = max(M, s);
|
||||||
|
|
||||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M);
|
// const half ms = m == -INFINITY ? 0.0h : exp(m - M);
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M);
|
// const half vs = s == -INFINITY ? 0.0h : exp(s - M);
|
||||||
|
|
||||||
S = S*ms + vs;
|
// S = S*ms + vs;
|
||||||
|
|
||||||
ss[j*T + 0 + p] = ms;
|
// ss[j*T + 0 + p] = ms;
|
||||||
|
// ss[j*T + C + p] = vs;
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
// not sure why this barrier is needed
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
const int64_t p = tiisg % C;
|
||||||
|
|
||||||
|
const half s = ss[j*T + p + 0]*scale + (mp[j][iic + p]);
|
||||||
|
|
||||||
|
half m = M[j];
|
||||||
|
|
||||||
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
|
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
|
S[j] = S[j]*ms + 0.25h*simd_sum(vs);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
|
ps4[j][i] *= ms;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tiisg < C) {
|
||||||
ss[j*T + C + p] = vs;
|
ss[j*T + C + p] = vs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int p = 0; p < C; ++p) {
|
for (int64_t p = 0; p < C; ++p) {
|
||||||
const int64_t ic = iic + p;
|
const int64_t ic = iic + p;
|
||||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const half ms = ss[j*T + 0 + p];
|
|
||||||
const half vs = ss[j*T + C + p];
|
const half vs = ss[j*T + C + p];
|
||||||
|
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
|
ps4[j][i] += pv4[N4*i + tiisg]*vs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//for (int p = 0; p < C; ++p) {
|
||||||
|
// const int64_t ic = iic + p;
|
||||||
|
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
|
// for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
// const half ms = ss[j*T + 0 + p];
|
||||||
|
// const half vs = ss[j*T + C + p];
|
||||||
|
|
||||||
|
// for (int64_t i = 0; i < L4; ++i) {
|
||||||
|
// ps4[j][i] = ps4[j][i]*ms + pv4[N4*i + tiisg]*vs;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiisg < Q) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const int64_t j = tiisg;
|
if (tiisg == 0) {
|
||||||
|
ss[j*T + 0] = S[j];
|
||||||
ss[j*T + 0] = S;
|
ss[j*T + 1] = M[j];
|
||||||
ss[j*T + 1] = M;
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2198,6 +2238,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
// reduce the warps
|
// reduce the warps
|
||||||
|
|
||||||
|
half S = { 0.0h };
|
||||||
|
half M = { -INFINITY };
|
||||||
|
|
||||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||||
if (sgitg == sg) {
|
if (sgitg == sg) {
|
||||||
// store heads to shared memory - reuse pq4
|
// store heads to shared memory - reuse pq4
|
||||||
@ -2241,7 +2284,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
S = ss[j*T + 0];
|
const half S = ss[j*T + 0];
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
ps4[j][i] = ps4[j][i]/S;
|
ps4[j][i] = ps4[j][i]/S;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user