mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
metal : fix F32 accumulation in FA vec kernel (#10232)
This commit is contained in:
parent
f018acba22
commit
bb38cdd8ba
@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
qk_t mqk = 0.0;
|
||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||
|
||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
k4x4_t mk;
|
||||
deq_k(pk + i/nl_k, i%nl_k, mk);
|
||||
|
||||
mqk +=
|
||||
dot(mq[ii/NL][0], mk[0]) +
|
||||
dot(mq[ii/NL][1], mk[1]) +
|
||||
dot(mq[ii/NL][2], mk[2]) +
|
||||
dot(mq[ii/NL][3], mk[3]);
|
||||
mqka[0] += dot(mq[ii/NL][0], mk[0]);
|
||||
mqka[1] += dot(mq[ii/NL][1], mk[1]);
|
||||
mqka[2] += dot(mq[ii/NL][2], mk[2]);
|
||||
mqka[3] += dot(mq[ii/NL][3], mk[3]);
|
||||
}
|
||||
|
||||
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
|
||||
|
||||
// simdgroup reduce
|
||||
// [ 0 .. 7] -> [ 0]
|
||||
// [ 8 .. 15] -> [ 8]
|
||||
|
Loading…
Reference in New Issue
Block a user