mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
metal : minor fixup in FA kernel (#10143)
* metal : minor fixup in FA kernel ggml-ci * metal : use the unrolled loop variable * metal : remove unused var
This commit is contained in:
parent
1839f69130
commit
08828a6d7d
@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
const short iv3 = iq3 / rv3;
|
const short iv3 = iq3 / rv3;
|
||||||
|
|
||||||
// load the queries from shared memory into local memory
|
// load the queries from shared memory into local memory
|
||||||
float4 mq[D4];
|
float4 mq[D4/NW];
|
||||||
|
|
||||||
for (short ii = 0; ii < D4; ii += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
short i = ii + tiisg;
|
short i = ii + tiisg;
|
||||||
mq[i] = (float4) sq4[i];
|
mq[ii/NW] = (float4) sq4[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
// pointer to the mask
|
||||||
@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
||||||
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
||||||
|
|
||||||
mqk += (float4) (mq[i] * mk);
|
mqk += (float4) (mq[ii/NW] * mk);
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce the results from the threads in the simdgroup
|
// reduce the results from the threads in the simdgroup
|
||||||
@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short ii = 0; ii < D4; ii += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
const short i = ii + tiisg;
|
lo[ii/NW] *= ms;
|
||||||
lo[i/NW] *= ms;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
for (short ii = 0; ii < D4; ii += NW) {
|
for (short ii = 0; ii < D4; ii += NW) {
|
||||||
const short i = ii + tiisg;
|
const short i = ii + tiisg;
|
||||||
|
|
||||||
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||||
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||||
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||||
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user