mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
wip : good version 8x32
This commit is contained in:
parent
eb12e3c391
commit
f6416d4493
@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
|
const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps)
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
const int64_t ncpsg = 8;
|
const int64_t ncpsg = 32;
|
||||||
|
|
||||||
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||||
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
|
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
|
||||||
|
@ -2072,9 +2072,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiisg < 1) {
|
if (tiisg < C) {
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
ss[j*T + tiisg] = 0.0h;
|
ss[j*T + 0 + tiisg] = 0.0h;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2128,36 +2128,26 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
||||||
//{
|
|
||||||
// bool skip = true;
|
|
||||||
// for (int64_t j = 0; j < Q; ++j) {
|
|
||||||
// skip = skip && (mp[j][iic] == -INFINITY);
|
|
||||||
// }
|
|
||||||
// if (skip) {
|
|
||||||
// continue;
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
simdgroup_half8x8 mk;
|
simdgroup_half8x8 mk;
|
||||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
|
||||||
|
|
||||||
device const half * pk = (device const half *) ((device const char *) k + (iic*nb11 + ik2*nb12 + ik3*nb13));
|
for (int cc = 0; cc < 4; ++cc) {
|
||||||
|
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||||
|
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
|
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
|
simdgroup_load(mk, pk + i*8, nb11/2, 0, true);
|
||||||
|
|
||||||
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_store(mqk, ss, T, 0, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// not sure why this barrier is needed
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
const int64_t p = tiisg % C;
|
const int64_t p = tiisg;
|
||||||
|
|
||||||
const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
||||||
|
|
||||||
@ -2168,37 +2158,38 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms + 0.25h*simd_sum(vs); // 4*8 = 32
|
S[j] = S[j]*ms + simd_sum(vs);
|
||||||
|
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
ls4[j][i] *= ms;
|
ls4[j][i] *= ms;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiisg < C) {
|
ss[j*T + p] = vs;
|
||||||
ss[j*T + p] = vs;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
simdgroup_half8x8 mv;
|
simdgroup_half8x8 mv;
|
||||||
simdgroup_half8x8 mp;
|
|
||||||
simdgroup_half8x8 mqkv;
|
|
||||||
|
|
||||||
device const half * pv = (device const half *) ((device const char *) v + (iic*nb21 + iv2*nb22 + iv3*nb23));
|
|
||||||
|
|
||||||
// load mp
|
|
||||||
simdgroup_load(mp, ss, T, 0, false);
|
|
||||||
|
|
||||||
for (int64_t i = 0; i < D8; ++i) {
|
for (int64_t i = 0; i < D8; ++i) {
|
||||||
simdgroup_load (mv, pv + i*8, nb21/2, 0, false);
|
simdgroup_half8x8 mp[4];
|
||||||
simdgroup_multiply(mqkv, mp, mv);
|
simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||||
simdgroup_store (mqkv, ps + i*8, T, 0, false);
|
|
||||||
|
for (int cc = 0; cc < 4; ++cc) {
|
||||||
|
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int cc = 0; cc < 4; ++cc) {
|
||||||
|
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
|
simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
|
||||||
|
|
||||||
|
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_store(mqkv, ps + i*8, T, 0, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// not sure why this barrier is needed too
|
|
||||||
threadgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
for (int64_t j = 0; j < Q; ++j) {
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
for (int64_t i = 0; i < L4; ++i) {
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
|
ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
|
||||||
@ -2284,9 +2275,9 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 8>;
|
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 8>;
|
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
|
||||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 8>;
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
|
Loading…
Reference in New Issue
Block a user