metal : more precise Q*K in FA vec kernel (#10247)

This commit is contained in:
Georgi Gerganov 2024-11-11 08:39:13 +02:00 committed by GitHub
parent b141e5f6ef
commit b0cefea58a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
half smax = -INFINITY; half smax = -INFINITY;
// load the mask in shared memory // load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
// we can read directly from global memory // we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll #pragma unroll(D8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
k8x8_t mk; k8x8_t mk;
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll #pragma unroll(4)
for (short k = 0; k < 4; ++k) { for (short k = 0; k < 4; ++k) {
k8x8_t mk; k8x8_t mk;
@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
s8x8_t mm; s8x8_t mm;
simdgroup_load(mm, ss + 2*C, TS, 0, false); simdgroup_load(mm, ss + 2*C, TS, 0, false);
#pragma unroll #pragma unroll(D8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]); simdgroup_multiply(lo[i], mm, lo[i]);
} }
@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
if (is_same<vd4x4_t, v4x4_t>::value) { if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory // we can read directly from global memory
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
v8x8_t mv; v8x8_t mv;
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_barrier(mem_flags::mem_threadgroup); simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll #pragma unroll(4)
for (short k = 0; k < 4; ++k) { for (short k = 0; k < 4; ++k) {
v8x8_t mv; v8x8_t mv;
@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms0, ss + 2*C, TS, 0, false); simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) { for (short i = 0; i < D8; ++i) {
o8x8_t t; o8x8_t t;
@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
q4x4_t mq[D16/NL]; q4x4_t mq[D16/NL];
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < D16; ii += NL) {
mq[ii/NL] = sq4x4[ii + tx]; mq[ii/NL] = sq4x4[ii + tx];
} }
@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll #pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx; const short i = ii + tx;
k4x4_t mk; k4x4_t mk;
deq_k(pk + i/nl_k, i%nl_k, mk); deq_k(pk + i/nl_k, i%nl_k, mk);
mqka[0] += dot(mq[ii/NL][0], mk[0]); // note: this is less precise than the version below
mqka[1] += dot(mq[ii/NL][1], mk[1]); //mqka[0] += dot(mq[ii/NL][0], mk[0]);
mqka[2] += dot(mq[ii/NL][2], mk[2]); //mqka[1] += dot(mq[ii/NL][1], mk[1]);
mqka[3] += dot(mq[ii/NL][3], mk[3]); //mqka[2] += dot(mq[ii/NL][2], mk[2]);
//mqka[3] += dot(mq[ii/NL][3], mk[3]);
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
} }
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
ss[tiisg] = vs; ss[tiisg] = vs;
// O = diag(ms)*O // O = diag(ms)*O
#pragma unroll #pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < D16; ii += NL) {
lo[ii/NL] *= ms; lo[ii/NL] *= ms;
} }
@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
// O = O + (Q*K^T)*V // O = O + (Q*K^T)*V
{ {
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) { for (short cc = 0; cc < C/4; ++cc) {
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
const s4x4_t ms(ss[4*cc + ty]); const s4x4_t ms(ss[4*cc + ty]);
#pragma unroll #pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) { for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx; const short i = ii + tx;