metal : scale and mask in matrix form

This commit is contained in:
Georgi Gerganov 2024-01-25 15:47:52 +02:00
parent d917746ddb
commit 432ad04ffa
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16(
} }
} }
// prepare diagonal scale matrix
simdgroup_half8x8 mscale(scale);
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
// skip -INF blocks // skip -INF blocks
// TODO: double-check this // TODO: double-check this
@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16(
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D8; ++i) { for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mk, pk + i*8, nb11/2, 0, true); simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true);
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
} }
// mqk = mqk*scale + mask
simdgroup_float8x8 mm;
simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
simdgroup_store(mqk, ss + 8*cc, T, 0, false); simdgroup_store(mqk, ss + 8*cc, T, 0, false);
} }
} }
@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t j = 0; j < Q; ++j) { for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg; const int64_t p = tiisg;
const half s = ss[j*T + p]*scale + (mp[j][iic + p]); //const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
const half s = ss[j*T + p];
half m = M[j]; half m = M[j];
@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16(
for (int cc = 0; cc < C/8; ++cc) { for (int cc = 0; cc < C/8; ++cc) {
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
simdgroup_load(mv, pv + i*8, nb21/2, 0, false); simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv); simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
} }