mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 22:59:24 +01:00
metal : avoid redundant loads of the attention
This commit is contained in:
parent
1446a12b29
commit
d917746ddb
@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// (Q*K^T)*V
|
||||
{
|
||||
simdgroup_half8x8 mv;
|
||||
|
||||
simdgroup_half8x8 mp[C/8];
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mp[C/8];
|
||||
simdgroup_half8x8 mqkv;
|
||||
|
||||
simdgroup_load(mqkv, ps + i*8, T, 0, false);
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user