Mirror of https://github.com/ggerganov/llama.cpp.git
metal : avoid redundant loads of the attention
commit d917746ddb
parent 1446a12b29
@@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16(
                 ss[j*T + p] = vs;
             }
 
+            simdgroup_barrier(mem_flags::mem_none);
+
             // (Q*K^T)*V
             {
                 simdgroup_half8x8 mv;
 
+                simdgroup_half8x8 mp[C/8];
+                for (int cc = 0; cc < C/8; ++cc) {
+                    simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
+                }
+
                 for (int64_t i = 0; i < D8; ++i) {
-                    simdgroup_half8x8 mp[C/8];
                     simdgroup_half8x8 mqkv;
 
                     simdgroup_load(mqkv, ps + i*8, T, 0, false);
 
-                    for (int cc = 0; cc < C/8; ++cc) {
-                        simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
-                    }
-
                     for (int cc = 0; cc < C/8; ++cc) {
                         device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
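For context, the change is plain loop-invariant code motion: the attention tile stored in ss does not change across the iterations over the head-dimension slices (D8), so the simdgroup_load of mp[C/8] can happen once before that loop instead of on every iteration. Below is a minimal standalone sketch of the same pattern in plain C++, not the kernel's actual code; the constants and the Tile8x8/load_tile stand-ins are hypothetical.

// Sketch (plain C++, hypothetical names) of the loop-invariant hoisting applied
// in this commit: the attention tile `ss` is identical for every head-dimension
// slice, so it is loaded into registers once instead of D8 times.
#include <array>
#include <cstdio>

constexpr int C  = 32;  // columns of K/V processed per block (assumed value)
constexpr int D8 = 16;  // head dimension in 8-wide slices (assumed value)

// Stand-in for a simdgroup 8x8 tile held in registers.
struct Tile8x8 { float sum = 0.0f; };

// Stand-in for simdgroup_load: reads one 8x8 tile from shared memory.
static Tile8x8 load_tile(const float * src) { return Tile8x8{ src[0] }; }

int main() {
    std::array<float, C> ss{};          // attention scores (shared memory in the kernel)
    for (int i = 0; i < C; ++i) ss[i] = float(i);

    // Before this commit, mp[] was reloaded from `ss` inside the loop over the
    // head dimension: D8 * (C/8) loads, even though `ss` never changes there.
    // After: load the tiles once up front and reuse them, C/8 loads total.
    Tile8x8 mp[C/8];
    for (int cc = 0; cc < C/8; ++cc) {
        mp[cc] = load_tile(&ss[8*cc]);
    }

    float acc = 0.0f;
    for (int i = 0; i < D8; ++i) {      // loop over head-dimension slices
        for (int cc = 0; cc < C/8; ++cc) {
            acc += mp[cc].sum;          // accumulate using the preloaded tiles
        }
    }

    std::printf("acc = %f\n", acc);
    return 0;
}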