metal : support Q > 8

Georgi Gerganov 2024-01-28 23:08:31 +02:00
parent 134c81c78d
commit 1db22d7032
GPG Key ID: 449E073F9DC10735
3 changed files with 55 additions and 34 deletions

View File

@@ -104,7 +104,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
- ctx_params.n_batch = 512;
+ ctx_params.n_batch = 2048;
ctx_params.mul_mat_q = mmq;
ctx_params.n_threads = params.n_threads;

View File

@@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
- const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8)
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32)
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
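(Note, not part of the commit: the two GGML_ASSERT checks added above replace the old "(multiple of 8)" / "(multiple of 32)" comments with hard constraints. A minimal host-side C++ sketch of what they guarantee, with a made-up helper name; per the "sync with kernel template arguments" comments, nqptg presumably becomes the kernel's Q template argument, and nqptg/8 is what the kernel below calls Q8.)

    #include <cassert>
    #include <cstdint>

    // hypothetical helper, for illustration only
    static int64_t fattn_query_blocks(int64_t nqptg, int64_t ncpsg) {
        assert(nqptg % 8  == 0); // queries come in whole 8-row simdgroup-matrix blocks
        assert(ncpsg % 32 == 0); // cache values come in whole 32-wide chunks
        return nqptg / 8;        // number of 8-row query blocks per simdgroup (Q8)
    }

    int main() {
        assert(fattn_query_blocks( 8, 32) == 1); // the old Q == 8 limit: a single block
        assert(fattn_query_blocks(16, 32) == 2); // Q > 8 now just means more blocks
        return 0;
    }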

View File

@@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16(
const int64_t D4 = D/4;
const int64_t D8 = D/8;
+ const int64_t Q8 = Q/8;
const int64_t NW = N_SIMDWIDTH;
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
@@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16(
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- simdgroup_half8x8 lo[D8];
+ simdgroup_half8x8 lo[Q8][D8];
// load heads from Q to shared memory
for (int64_t j = sgitg; j < Q; j += nsg) {
@@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16(
}
// zero out lo
- for (int64_t i = 0; i < D8; ++i) {
- lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ for (int64_t j = 0; j < Q8; ++j) {
+ for (int64_t i = 0; i < D8; ++i) {
+ lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ }
}
// zero out shared memory SH
@@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16(
const int64_t iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
- simdgroup_half8x8 mq[D8];
+ simdgroup_half8x8 mq[Q8][D8];
- for (int64_t i = 0; i < D8; ++i) {
- simdgroup_load(mq[i], sq + i*8, T);
+ for (int64_t j = 0; j < Q8; ++j) {
+ for (int64_t i = 0; i < D8; ++i) {
+ simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
+ }
}
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
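(Note, not part of the commit: the new 8*j*T + i*8 terms in the loads above, and in the stores later in this diff, come from treating sq as a row-major block of half values with row stride T: query block j starts 8*j rows in, and head-dimension block i starts at column 8*i. A small C++ sketch of that addressing, with made-up sizes:)

    #include <cassert>

    int main() {
        const long T = 96, Q = 16, D = 64;   // illustrative row stride (in halves) and sizes
        auto at = [&](long row, long col) { return row*T + col; };

        for (long j = 0; j < Q/8; ++j)       // query blocks (Q8 in the kernel)
            for (long i = 0; i < D/8; ++i)   // head-dimension blocks (D8)
                // base offset of the 8x8 tile (j, i) == offset of its top-left element (8*j, 8*i)
                assert(8*j*T + i*8 == at(8*j, 8*i));
        return 0;
    }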
@@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16(
// Q*K^T
{
for (int cc = 0; cc < C/8; ++cc) {
- simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
+ simdgroup_half8x8 mqk[Q8];
+ for (int64_t j = 0; j < Q8; ++j) {
+ mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
+ }
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_half8x8 mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ for (int64_t j = 0; j < Q8; ++j) {
+ simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
+ }
}
// mqk = mqk*scale + mask
- simdgroup_float8x8 mm;
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false);
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+ for (int64_t j = 0; j < Q8; ++j) {
+ simdgroup_float8x8 mm;
+ simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
+ simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
- simdgroup_store(mqk, ss + 8*cc, T, 0, false);
+ simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
+ }
}
}
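(Note, not part of the commit: for each 8-column key chunk cc the rewritten loop now produces Q8 independent 8x8 score tiles, one per 8-row query block j, roughly S_j = scale*(Q_j*K_cc^T) + mask_j, with the i-loop accumulating over the D8 head-dimension tiles. A scalar C++ reference of a single tile, with made-up sizes and data:)

    #include <cstdio>
    #include <vector>

    int main() {
        const int   D     = 32;                  // head dimension (illustrative)
        const float scale = 0.25f;               // e.g. 1/sqrt(D)
        std::vector<float> Qj(8*D, 0.01f);       // 8 queries of block j, row-major
        std::vector<float> Kc(8*D, 0.02f);       // 8 keys of chunk cc, row-major
        std::vector<float> M (8*8, 0.00f);       // matching 8x8 mask block
        float S[8][8];

        for (int r = 0; r < 8; ++r) {
            for (int c = 0; c < 8; ++c) {
                float acc = 0.0f;
                for (int d = 0; d < D; ++d) {        // plays the role of the i-loop over D8
                    acc += Qj[r*D + d]*Kc[c*D + d];  // K is used transposed, as in the kernel
                }
                S[r][c] = acc*scale + M[r*8 + c];    // mqk = mqk*scale + mask
            }
        }
        printf("S[0][0] = %f\n", S[0][0]);
        return 0;
    }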
@@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
S[j] = S[j]*ms;
- // create an 8x8 diagonal matrix for rescaling the output
+ // create a QxQ diagonal matrix for rescaling the output
if (tiisg == j) {
ss[j*T + C + j] = ms;
}
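(Note, not part of the commit: the per-row factor ms is computed just above this hunk and not shown here; it plays the usual online-softmax rescaling role, and with Q queries per simdgroup the factors now occupy a QxQ diagonal rather than an 8x8 one, hence the comment change. In the standard formulation, per query row q:

    m_q' = max(m_q, max_c s_{q,c}),  ms_q = exp(m_q - m_q'),  S_q <- ms_q * S_q,  O_q <- ms_q * O_q

Storing ms_q at ss[q*T + C + q] builds diag(ms), so the O update below can be done with whole 8x8 simdgroup multiplies, one per query block.)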
@@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16(
}
// O = diag(ms)*O
{
+ for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 mm;
- simdgroup_load(mm, ss + C, T, 0, false);
+ simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
- simdgroup_multiply(lo[i], mm, lo[i]);
+ simdgroup_multiply(lo[j][i], mm, lo[j][i]);
}
+ }
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/8; ++cc) {
- simdgroup_half8x8 mp;
- simdgroup_load(mp, ss + 8*cc, T, 0, false);
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
for (int64_t i = 0; i < D8; ++i) {
- device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
- simdgroup_half8x8 mv;
- simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
+ for (int64_t j = 0; j < Q8; ++j) {
+ simdgroup_half8x8 mv;
+ simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
- simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]);
+ simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
+ }
}
}
}
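(Note, not part of the commit: the restructured loop now loads each 8x8 value tile (mk, from pv) once per head-dimension block i and reuses it across all Q8 score tiles (mv, read back from ss), accumulating lo[j][i] += mv*mk, i.e. (8 queries x 8 cache positions) times (8 cache positions x 8 head-dim lanes). A scalar C++ reference of one such step, with made-up data:)

    #include <cstdio>

    int main() {
        float P[8][8], V[8][8], O[8][8]; // scores (mv), value tile (mk), output tile lo[j][i]

        for (int r = 0; r < 8; ++r)
            for (int c = 0; c < 8; ++c) { P[r][c] = 0.125f; V[r][c] = 0.5f; O[r][c] = 0.0f; }

        // O += P*V : rows are queries, the inner dimension is the 8 cache positions of chunk cc
        for (int q = 0; q < 8; ++q)
            for (int d = 0; d < 8; ++d)
                for (int k = 0; k < 8; ++k)
                    O[q][d] += P[q][k]*V[k][d];

        printf("O[0][0] = %f\n", O[0][0]); // 8 * 0.125 * 0.5 = 0.5
        return 0;
    }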
@@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16(
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
- for (int64_t i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ for (int64_t j = 0; j < Q8; ++j) {
+ for (int64_t i = 0; i < D8; ++i) {
+ simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+ }
}
}
@@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
+ for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 t;
simdgroup_half8x8 ms0;
simdgroup_half8x8 ms1;
- simdgroup_load(ms0, ss + C, T, 0, false);
- simdgroup_load(ms1, ss + C + sg*SH, T, 0, false);
+ simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false);
+ simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
- simdgroup_load (t, sq + i*8, T, 0, false);
+ simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t);
- simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
+ }
}
}
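(Note, not part of the commit: this is the cross-simdgroup reduction; simdgroup 0 folds in the partial output of simdgroup sg, rescaling both sides by their own diagonal factors before adding, now independently for every query block j:

    O_0[j] <- diag(ms0[j]) * O_0[j] + diag(ms1[j]) * O_1[j]

which is why the ms0/ms1 loads pick the 8x8 diagonal block at 8*j*T + C + 8*j, with sg*SH selecting the other simdgroup's scratch area.)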
@@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16(
// store result to shared memory (reuse sq)
if (sgitg == 0) {
- for (int64_t i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ for (int64_t j = 0; j < Q8; ++j) {
+ for (int64_t i = 0; i < D8; ++i) {
+ simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+ }
}
}