metal : avoid redundant loads of the attention

This commit is contained in:
Georgi Gerganov 2024-01-25 15:00:49 +02:00
parent 1446a12b29
commit d917746ddb
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -2184,20 +2184,22 @@ kernel void kernel_flash_attn_ext_f16(
ss[j*T + p] = vs;
}
simdgroup_barrier(mem_flags::mem_none);
// (Q*K^T)*V
{
simdgroup_half8x8 mv;
simdgroup_half8x8 mp[C/8];
for (int cc = 0; cc < C/8; ++cc) {
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
}
for (int64_t i = 0; i < D8; ++i) {
simdgroup_half8x8 mp[C/8];
simdgroup_half8x8 mqkv;
simdgroup_load(mqkv, ps + i*8, T, 0, false);
for (int cc = 0; cc < C/8; ++cc) {
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
}
for (int cc = 0; cc < C/8; ++cc) {
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));