mirror of https://github.com/ggerganov/llama.cpp.git

commit f6416d4493
parent eb12e3c391

wip : good version 8x32
@@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute(
                 [encoder setBytes:&ne3   length:sizeof( int64_t) atIndex:26];
                 [encoder setBytes:&scale length:sizeof(   float) atIndex:27];

-                const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
+                const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps)

                 const int64_t nqptg = 8; // queries per threadgroup   !! sync with kernel template arguments !!
-                const int64_t ncpsg = 8;
+                const int64_t ncpsg = 32;

                 //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
                 const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
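As a rough sanity check on the shared-memory formula with the new launch parameters, here is a standalone C++ sketch (not part of the patch; ne00 = 128, i.e. the h128 head size, is an assumed example value). For these numbers it comes out to 7168 bytes, comfortably within the 32 KB of threadgroup memory available on Apple GPUs.

// Back-of-the-envelope check of the threadgroup memory requested above:
// nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2).
// Values mirror the diff; ne00 = 128 is an assumed example head size.
#include <cstdio>

int main() {
    const long nsg   = 2;   // simdgroups per threadgroup
    const long nqptg = 8;   // queries per threadgroup
    const long ncpsg = 32;  // cache entries processed per simdgroup iteration
    const long ne00  = 128; // head size (assumption for this example)

    const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);

    printf("smem = %zu bytes\n", smem); // 8*(128 + 2*160)*2 = 7168 bytes
    return 0;
}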
@@ -2072,9 +2072,9 @@ kernel void kernel_flash_attn_ext_f16(
        }
    }

-    if (tiisg < 1) {
+    if (tiisg < C) {
        for (int64_t j = 0; j < Q; ++j) {
-            ss[j*T + tiisg] = 0.0h;
+            ss[j*T + 0 + tiisg] = 0.0h;
        }
    }

@@ -2128,36 +2128,26 @@ kernel void kernel_flash_attn_ext_f16(
        }

        for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
-            //{
-            //    bool skip = true;
-            //    for (int64_t j = 0; j < Q; ++j) {
-            //        skip = skip && (mp[j][iic] == -INFINITY);
-            //    }
-            //    if (skip) {
-            //        continue;
-            //    }
-            //}
-
            {
                simdgroup_half8x8 mk;
-                simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);

-                device const half * pk = (device const half *) ((device const char *) k + (iic*nb11 + ik2*nb12 + ik3*nb13));
+                for (int cc = 0; cc < 4; ++cc) {
+                    simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);

-                for (int64_t i = 0; i < D8; ++i) {
-                    simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
+                    device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

-                    simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
-                }
+                    for (int64_t i = 0; i < D8; ++i) {
+                        simdgroup_load(mk, pk + i*8, nb11/2, 0, true);

-                simdgroup_store(mqk, ss, T, 0, false);
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    simdgroup_store(mqk, ss + 8*cc, T, 0, false);
+                }
            }

            // not sure why this barrier is needed
            simdgroup_barrier(mem_flags::mem_none);

            for (int64_t j = 0; j < Q; ++j) {
-                const int64_t p = tiisg % C;
+                const int64_t p = tiisg;

                const half s = ss[j*T + p]*scale + (mp[j][iic + p]);

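With this change each simdgroup scores a block of Q = 8 queries against C = 32 cache entries per iteration, computed as four 8x8 simdgroup matrix products whose results are stored at column offset 8*cc in the shared score buffer ss. Below is a scalar C++ reference of that tile computation (an illustrative sketch only; q, k, ss, D and T are stand-in names, not the kernel's actual memory layout).

// Scalar reference for the 8x32 score tile: an 8 x D query block against a
// chunk of 32 keys, processed as four 8x8 sub-tiles (cc = 0..3), mirroring
// the simdgroup loop structure in the diff above.
#include <cstdio>
#include <vector>

void qk_tile_8x32(const std::vector<float>& q,  // 8 x D queries, row-major
                  const std::vector<float>& k,  // 32 x D keys, row-major
                  std::vector<float>& ss,       // 8 x T score buffer, T >= 32
                  int D, int T) {
    for (int cc = 0; cc < 4; ++cc) {         // four 8-wide chunks -> 32 columns
        for (int j = 0; j < 8; ++j) {        // query row
            for (int c = 0; c < 8; ++c) {    // column within the chunk
                float acc = 0.0f;
                for (int i = 0; i < D; ++i) {
                    acc += q[j*D + i] * k[(8*cc + c)*D + i];
                }
                ss[j*T + 8*cc + c] = acc;     // same placement as ss + 8*cc
            }
        }
    }
}

int main() {
    const int D = 4, T = 32;
    std::vector<float> q(8*D, 1.0f), k(32*D, 1.0f), ss(8*T, 0.0f);
    qk_tile_8x32(q, k, ss, D, T);
    // each score is the dot product of two all-ones vectors of length D
    printf("ss[0] = %f (expected %d)\n", ss[0], D);
    return 0;
}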
@@ -2168,37 +2158,38 @@ kernel void kernel_flash_attn_ext_f16(
                const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
                const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);

-                S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32
+                S[j] = S[j]*ms + simd_sum(vs);

                for (int64_t i = 0; i < L4; ++i) {
                    ls4[j][i] *= ms;
                }

-                if (tiisg < C) {
-                    ss[j*T + p] = vs;
-                }
+                ss[j*T + p] = vs;
            }

            {
                simdgroup_half8x8 mv;
-                simdgroup_half8x8 mp;
-                simdgroup_half8x8 mqkv;
-
-                device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23));
-
-                // load mp
-                simdgroup_load(mp, ss, T, 0, false);

                for (int64_t i = 0; i < D8; ++i) {
-                    simdgroup_load (mv, pv + i*8, nb21/2, 0, false);
-                    simdgroup_multiply(mqkv, mp, mv);
-                    simdgroup_store (mqkv, ps + i*8, T, 0, false);
+                    simdgroup_half8x8 mp[4];
+                    simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix<half, Q>(0.h);
+
+                    for (int cc = 0; cc < 4; ++cc) {
+                        simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
+                    }
+
+                    for (int cc = 0; cc < 4; ++cc) {
+                        device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                        simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
+
+                        simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
+                    }
+
+                    simdgroup_store(mqkv, ps + i*8, T, 0, false);
                }
            }

            // not sure why this barrier is needed too
            threadgroup_barrier(mem_flags::mem_none);

            for (int64_t j = 0; j < Q; ++j) {
                for (int64_t i = 0; i < L4; ++i) {
                    ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
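Note on the dropped 0.25h factor: with the old C = 8 layout, p = tiisg % C meant each of the 8 scores was replicated in 4 of the 32 simdgroup lanes, so simd_sum over-counted by 4x (hence the "// 4*8 = 32" comment). With C = 32 and p = tiisg every lane holds a distinct score and the plain simd_sum is the correct row sum. A scalar C++ sketch of the running (online) softmax update done per query row, assuming the same rescale-then-accumulate scheme (names illustrative; the kernel does this in half precision with simd_sum):

// When a new chunk of scores arrives, the running max M and running sum S
// (and, in the kernel, the accumulated V rows) are rescaled by exp(M_old -
// M_new) before the new exp terms are added.
#include <cmath>
#include <cstdio>
#include <vector>

void online_softmax_update(float& M, float& S, const std::vector<float>& s) {
    float m_new = M;
    for (float x : s) m_new = std::fmax(m_new, x);

    const float ms = std::exp(M - m_new);   // rescale factor for the old terms
    float vsum = 0.0f;
    for (float x : s) vsum += std::exp(x - m_new);

    S = S*ms + vsum;                        // matches S[j] = S[j]*ms + simd_sum(vs)
    M = m_new;
}

int main() {
    float M = -INFINITY, S = 0.0f;
    online_softmax_update(M, S, {0.5f, 1.0f, -2.0f});
    online_softmax_update(M, S, {3.0f, 0.0f});
    printf("M = %f, S = %f\n", M, S);
    return 0;
}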
@@ -2284,9 +2275,9 @@ kernel void kernel_flash_attn_ext_f16(
    }
}

-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 8>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 8>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 8>;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;

kernel void kernel_cpy_f16_f16(
        device const half * src0,