mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 06:40:45 +01:00
metal : scale and mask in matrix form
This commit is contained in:
parent
d917746ddb
commit
432ad04ffa
@ -2127,6 +2127,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prepare diagonal scale matrix
|
||||||
|
simdgroup_half8x8 mscale(scale);
|
||||||
|
|
||||||
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
||||||
// skip -INF blocks
|
// skip -INF blocks
|
||||||
// TODO: double-check this
|
// TODO: double-check this
|
||||||
@ -2153,11 +2156,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
|
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true);
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mqk = mqk*scale + mask
|
||||||
|
simdgroup_float8x8 mm;
|
||||||
|
simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false);
|
||||||
|
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
||||||
|
|
||||||
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
|
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2166,7 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const int64_t p = tiisg;
|
const int64_t p = tiisg;
|
||||||
|
|
||||||
const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
//const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
||||||
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
half m = M[j];
|
half m = M[j];
|
||||||
|
|
||||||
@ -2203,7 +2212,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
for (int cc = 0; cc < C/8; ++cc) {
|
for (int cc = 0; cc < C/8; ++cc) {
|
||||||
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
|
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
|
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user