wip : good version 8x32

This commit is contained in:
Georgi Gerganov 2024-01-25 12:59:59 +02:00
parent eb12e3c391
commit f6416d4493
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
2 changed files with 36 additions and 45 deletions

View File

@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps)
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 8;
const int64_t ncpsg = 32;
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);

View File

@ -2072,9 +2072,9 @@ kernel void kernel_flash_attn_ext_f16(
}
}
if (tiisg < 1) {
if (tiisg < C) {
for (int64_t j = 0; j < Q; ++j) {
ss[j*T + tiisg] = 0.0h;
ss[j*T + 0 + tiisg] = 0.0h;
}
}
@ -2128,36 +2128,26 @@ kernel void kernel_flash_attn_ext_f16(
}
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
//{
// bool skip = true;
// for (int64_t j = 0; j < Q; ++j) {
// skip = skip && (mp[j][iic] == -INFINITY);
// }
// if (skip) {
// continue;
// }
//}
{
simdgroup_half8x8 mk;
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
device const half * pk = (device const half *) ((device const char *) k + (iic*nb11 + ik2*nb12 + ik3*nb13));
for (int cc = 0; cc < 4; ++cc) {
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
}
simdgroup_store(mqk, ss, T, 0, false);
}
// not sure why this barrier is needed
simdgroup_barrier(mem_flags::mem_none);
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg % C;
const int64_t p = tiisg;
const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
@ -2168,37 +2158,38 @@ kernel void kernel_flash_attn_ext_f16(
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32
S[j] = S[j]*ms + simd_sum(vs);
for (int64_t i = 0; i < L4; ++i) {
ls4[j][i] *= ms;
}
if (tiisg < C) {
ss[j*T + p] = vs;
}
ss[j*T + p] = vs;
}
{
simdgroup_half8x8 mv;
simdgroup_half8x8 mp;
simdgroup_half8x8 mqkv;
device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23));
// load mp
simdgroup_load(mp, ss, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load (mv, pv + i*8, nb21/2, 0, false);
simdgroup_multiply(mqkv, mp, mv);
simdgroup_store (mqkv, ps + i*8, T, 0, false);
simdgroup_half8x8 mp[4];
simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix<half, Q>(0.h);
for (int cc = 0; cc < 4; ++cc) {
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
}
for (int cc = 0; cc < 4; ++cc) {
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
}
simdgroup_store(mqkv, ps + i*8, T, 0, false);
}
}
// not sure why this barrier is needed too
threadgroup_barrier(mem_flags::mem_none);
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = 0; i < L4; ++i) {
ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
@ -2284,9 +2275,9 @@ kernel void kernel_flash_attn_ext_f16(
}
}
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 8>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 8>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 8>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
kernel void kernel_cpy_f16_f16(
device const half * src0,