mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
metal : move output into local memory + optimize
- the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments
This commit is contained in:
parent
b3dd7d975f
commit
77f6976a87
12
ggml-metal.m
12
ggml-metal.m
@ -2213,14 +2213,14 @@ static bool ggml_metal_graph_compute(
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps)
|
||||
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup
|
||||
|
||||
//const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ne00 + 1*ncpsg))*(sizeof(float)/2);
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/32, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
|
272
ggml-metal.metal
272
ggml-metal.metal
@ -1995,6 +1995,7 @@ typedef void (flash_attn_ext_f16_t)(
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_f16(
|
||||
device const char * q,
|
||||
@ -2038,39 +2039,45 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
const int64_t iq1 = tgpig[0]*Q;
|
||||
|
||||
const int64_t D4 = D/4;
|
||||
const int64_t N4 = N_SIMDWIDTH;
|
||||
const int64_t L4 = (D4 + N4 - 1)/N4;
|
||||
const int64_t D8 = D/8;
|
||||
const int64_t NW = N_SIMDWIDTH;
|
||||
const int64_t L4 = (D4 + NW - 1)/NW;
|
||||
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const int64_t T = D + nsg*(D + 1*C); // shared memory size per query in half
|
||||
const int64_t T4 = T/4; // shared memory size per query in half4
|
||||
const int64_t T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const int64_t T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * pq = (threadgroup half *) (shared + 0*D);
|
||||
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*D);
|
||||
threadgroup half * ps = (threadgroup half *) (shared + sgitg*(D + 1*C) + 1*D);
|
||||
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(D + 1*C) + 1*D);
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(D + 1*C) + 2*D);
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // scratch buffer for attention
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[D8];
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t j = sgitg; j < Q; j += nsg) {
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
if (iq1 + j < ne01) {
|
||||
pq4[j*T4 + N4*i + tiisg] = (half4) q4[N4*i + tiisg];
|
||||
sq4[j*T4 + NW*i + tiisg] = (half4) q4[NW*i + tiisg];
|
||||
} else {
|
||||
pq4[j*T4 + N4*i + tiisg] = 0.0h;
|
||||
sq4[j*T4 + NW*i + tiisg] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
// zero out shared memory
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
ps4[j*T4 + N4*i + tiisg] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
if (tiisg < C) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
ss[j*T + 0 + tiisg] = 0.0h;
|
||||
ss[j*T + tiisg] = 0.0h;
|
||||
if (tiisg < Q) {
|
||||
ss[j*T + C + tiisg] = 0.0h;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2103,46 +2110,24 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
const int64_t iv2 = iq2 / rv2;
|
||||
const int64_t iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
simdgroup_half8x8 mq[D8];
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], pq + i*8, T);
|
||||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
}
|
||||
|
||||
// TODO: this can be improved
|
||||
device const float * mp[Q];
|
||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||
|
||||
{
|
||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
if (iq1 + j < ne01) {
|
||||
mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31);
|
||||
} else {
|
||||
mp[j] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
// pointer to the mask
|
||||
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
|
||||
|
||||
// prepare diagonal scale matrix
|
||||
simdgroup_half8x8 mscale(scale);
|
||||
|
||||
for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) {
|
||||
// skip -INF blocks
|
||||
// TODO: double-check this
|
||||
{
|
||||
float smc = -INFINITY;
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const float mc = mp[j] ? mp[j][iic + tiisg] : -INFINITY;
|
||||
smc = simd_max(max(smc, mc));
|
||||
}
|
||||
|
||||
if (smc == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
|
||||
// Q*K^T
|
||||
{
|
||||
simdgroup_half8x8 mk;
|
||||
@ -2150,7 +2135,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((iic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true);
|
||||
@ -2160,65 +2145,77 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
simdgroup_float8x8 mm;
|
||||
simdgroup_load(mm, mp[0] + iic + 8*cc, nb31/sizeof(float), 0, false);
|
||||
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false);
|
||||
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
||||
|
||||
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
half smax = -INFINITY;
|
||||
|
||||
// online softmax
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = tiisg;
|
||||
|
||||
//const half s = ss[j*T + p]*scale + (mp[j][iic + p]);
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
half m = M[j];
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
const half m = M[j];
|
||||
|
||||
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms + simd_sum(vs);
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j*T4 + N4*i + tiisg] *= ms;
|
||||
// create an 8x8 diagonal matrix for rescaling the output
|
||||
if (p == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// skip -INF blocks
|
||||
if (smax == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// (Q*K^T)*V
|
||||
// O = diag(ms)*O
|
||||
{
|
||||
simdgroup_half8x8 mm;
|
||||
|
||||
simdgroup_load(mm, ss + C, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_multiply(lo[i], mm, lo[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
simdgroup_half8x8 mv;
|
||||
|
||||
simdgroup_half8x8 mp[C/8];
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_load(mp[cc], ss + 8*cc, T, 0, false);
|
||||
}
|
||||
simdgroup_half8x8 mp;
|
||||
simdgroup_load(mp, ss + 8*cc, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mqkv;
|
||||
|
||||
simdgroup_load(mqkv, ps + i*8, T, 0, false);
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
|
||||
simdgroup_multiply_accumulate(mqkv, mp[cc], mv, mqkv);
|
||||
simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]);
|
||||
}
|
||||
|
||||
simdgroup_store(mqkv, ps + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
@ -2227,91 +2224,82 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// reduce the warps sequentially
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
// reduce the warps
|
||||
#if 1
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq4
|
||||
if (sgitg == sg) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*SH + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
|
||||
ss[j*T + C + j ] = ms0;
|
||||
ss[j*T + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
simdgroup_half8x8 t;
|
||||
simdgroup_half8x8 ms0;
|
||||
simdgroup_half8x8 ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, T, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// store result to shared memory (reuse sq4)
|
||||
if (sgitg == 0) {
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
for (int64_t sg = 1; sg < nsg; ++sg) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*(D + 1*C) + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1;
|
||||
}
|
||||
}
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// parallel reduce
|
||||
// NOTE: this is significantly slower than the serial version above, likely due to the small number of warps
|
||||
{
|
||||
half S = { 0.0h };
|
||||
half M = { -INFINITY };
|
||||
|
||||
for (int64_t sg = nsg/2; sg > 0; sg /= 2) {
|
||||
if (sgitg >= sg) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*(D + 1*C) + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
|
||||
for (int64_t i = 0; i < L4; ++i) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + N4*i + tiisg] = (float4) ps4[j*T4 + N4*i + tiisg]/S;
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + NW*i + tiisg] = (float4) sq4[j*T4 + NW*i + tiisg]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user