mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 06:39:25 +01:00
metal : clean up
This commit is contained in:
parent
f6416d4493
commit
2bf91c5306
@ -2253,9 +2253,9 @@ static bool ggml_metal_graph_compute(
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsg = 4; // simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32;
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
|
||||
|
||||
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
|
||||
|
@ -2055,10 +2055,8 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D);
|
||||
|
||||
half4 ls4[Q][L4];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t j = sgitg; j < Q; j += nsg) {
|
||||
if (iq1 + j < ne01) {
|
||||
pq4[j*T4 + N4*i + tiisg] = ((device const half4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + tiisg];
|
||||
@ -2067,8 +2065,9 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
// zero out shared memory
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
ls4[j][i] = 0.0h;
|
||||
ps4[j*T4 + N4*i + tiisg] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2113,6 +2112,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
simdgroup_load(mq[i], pq + i*8, T);
|
||||
}
|
||||
|
||||
// TODO: this can be improved
|
||||
device const float * mp[Q];
|
||||
|
||||
{
|
||||
@ -2128,10 +2128,26 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
||||
// skip -INF blocks
|
||||
// TODO: double-check this
|
||||
{
|
||||
float smc = -INFINITY;
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY;
|
||||
smc = simd_max(max(smc, mc));
|
||||
}
|
||||
|
||||
if (smc == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
simdgroup_half8x8 mk;
|
||||
|
||||
for (int cc = 0; cc < 4; ++cc) {
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
@ -2146,6 +2162,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
// online softmax
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = tiisg;
|
||||
|
||||
@ -2161,24 +2178,27 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
S[j] = S[j]*ms + simd_sum(vs);
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ls4[j][i] *= ms;
|
||||
ps4[j*T4 + N4*i + tiisg] *= ms;
|
||||
}
|
||||
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
|
||||
// (Q*K^T)*V
|
||||
{
|
||||
simdgroup_half8x8 mv;
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mp[4];
|
||||
simdgroup_half8x8 mqkv = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||
simdgroup_half8x8 mp[C/8];
|
||||
simdgroup_half8x8 mqkv;
|
||||
|
||||
for (int cc = 0; cc < 4; ++cc) {
|
||||
simdgroup_load(mqkv, ps + i*8, T, 0, false);
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < 4; ++cc) {
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
simdgroup_load(mv, pv + i*8, nb21/2, 0, false);
|
||||
@ -2189,12 +2209,6 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
simdgroup_store(mqkv, ps + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ls4[j][i] += ps4[j*T4 + N4*i + tiisg];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
@ -2208,23 +2222,12 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// reduce the warps
|
||||
// TODO: try parallel reduce
|
||||
if (sgitg == 0) {
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
if (sgitg == sg) {
|
||||
// store heads to shared memory - reuse pq4
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
pq4[j*T4 + N4*i + tiisg] = ls4[j][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||
@ -2245,21 +2248,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ls4[j][i] = ls4[j][i]*ms0 + pq4[j*T4 + N4*i + tiisg]*ms1;
|
||||
ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ls4[j][i] = ls4[j][i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@ -2268,8 +2260,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ls4[j][i];
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user