mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 22:59:24 +01:00
metal : support Q > 8
This commit is contained in:
parent
134c81c78d
commit
1db22d7032
@ -104,7 +104,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = n_kv_max;
|
||||
ctx_params.n_batch = 512;
|
||||
ctx_params.n_batch = 2048;
|
||||
ctx_params.mul_mat_q = mmq;
|
||||
|
||||
ctx_params.n_threads = params.n_threads;
|
||||
|
@ -2206,8 +2206,11 @@ static bool ggml_metal_graph_compute(
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! (multiple of 8)
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! (multiple of 32)
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
|
@ -2040,6 +2040,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
const int64_t D4 = D/4;
|
||||
const int64_t D8 = D/8;
|
||||
const int64_t Q8 = Q/8;
|
||||
const int64_t NW = N_SIMDWIDTH;
|
||||
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
@ -2051,7 +2052,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for diagonal matrix
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[D8];
|
||||
simdgroup_half8x8 lo[Q8][D8];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t j = sgitg; j < Q; j += nsg) {
|
||||
@ -2067,8 +2068,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
||||
}
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
@ -2108,10 +2111,12 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
const int64_t iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
simdgroup_half8x8 mq[D8];
|
||||
simdgroup_half8x8 mq[Q8][D8];
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||
@ -2128,7 +2133,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
// Q*K^T
|
||||
{
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||
simdgroup_half8x8 mqk[Q8];
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
|
||||
}
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
@ -2136,15 +2144,19 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
simdgroup_float8x8 mm;
|
||||
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(float), 0, false);
|
||||
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
simdgroup_float8x8 mm;
|
||||
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
|
||||
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
|
||||
|
||||
simdgroup_store(mqk, ss + 8*cc, T, 0, false);
|
||||
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2166,7 +2178,7 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
S[j] = S[j]*ms;
|
||||
|
||||
// create an 8x8 diagonal matrix for rescaling the output
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
@ -2189,28 +2201,30 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// O = diag(ms)*O
|
||||
{
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 mm;
|
||||
simdgroup_load(mm, ss + C, T, 0, false);
|
||||
simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_multiply(lo[i], mm, lo[i]);
|
||||
simdgroup_multiply(lo[j][i], mm, lo[j][i]);
|
||||
}
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mp;
|
||||
simdgroup_load(mp, ss + 8*cc, T, 0, false);
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]);
|
||||
simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2234,8 +2248,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (sgitg == sg) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2267,19 +2283,19 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
simdgroup_half8x8 t;
|
||||
simdgroup_half8x8 ms0;
|
||||
simdgroup_half8x8 ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, T, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, T, 0, false);
|
||||
simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false);
|
||||
simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
||||
simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2287,8 +2303,10 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
|
||||
// store result to shared memory (reuse sq)
|
||||
if (sgitg == 0) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
for (int64_t j = 0; j < Q8; ++j) {
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user