metal : another fix for the fa kernel

This commit is contained in:
Georgi Gerganov 2024-08-26 14:55:28 +03:00
parent 7a3df798fc
commit a95225cdfd
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2144,6 +2144,7 @@ kernel void kernel_flash_attn_ext_f16(
const short tx = tiisg%4; const short tx = tiisg%4;
const short ty = tiisg/4; const short ty = tiisg/4;
if (iq1 + ty < ne01) {
// mqk = mqk*scale // mqk = mqk*scale
ss[8*cc + ty*TF + 2*tx + 0] *= scale; ss[8*cc + ty*TF + 2*tx + 0] *= scale;
ss[8*cc + ty*TF + 2*tx + 1] *= scale; ss[8*cc + ty*TF + 2*tx + 1] *= scale;
@ -2160,6 +2161,7 @@ kernel void kernel_flash_attn_ext_f16(
} }
} }
} }
}
// used to detect blocks full of -INF // used to detect blocks full of -INF
float smax = -INFINITY; float smax = -INFINITY;