mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
metal : add parallel reduce version (disabled)
This commit is contained in:
parent
f9ca5dcbe8
commit
6fea843b24
@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||||
const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)
|
const int64_t nsg = ne01 < 4 ? 12 : 4; // simdgroups per threadgroup (a.k.a. warps)
|
||||||
|
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
|
const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
|
||||||
|
@ -2230,7 +2230,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// reduce the warps
|
// reduce the warps
|
||||||
// TODO: try parallel reduce
|
#if 1
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
half S = { 0.0h };
|
half S = { 0.0h };
|
||||||
half M = { -INFINITY };
|
half M = { -INFINITY };
|
||||||
@ -2261,6 +2261,46 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
// parallel reduce
|
||||||
|
// NOTE: this is significantly slower than the serial version above, likely due to the small number of warps
|
||||||
|
{
|
||||||
|
half S = { 0.0h };
|
||||||
|
half M = { -INFINITY };
|
||||||
|
|
||||||
|
for (int64_t sg = nsg/2; sg > 0; sg /= 2) {
|
||||||
|
if (sgitg >= sg) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
const half S0 = ss[j*T + 0];
|
||||||
|
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||||
|
|
||||||
|
const half M0 = ss[j*T + 1];
|
||||||
|
const half M1 = ss[j*T + sg*(D + 1*C) + 1];
|
||||||
|
|
||||||
|
M = max(M0, M1);
|
||||||
|
|
||||||
|
const half ms0 = exp(M0 - M);
|
||||||
|
const half ms1 = exp(M1 - M);
|
||||||
|
|
||||||
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
ss[j*T + 0] = S;
|
||||||
|
ss[j*T + 1] = M;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < L4; ++i) {
|
||||||
|
ps4[j*T4 + N4*i + tiisg] = ps4[j*T4 + N4*i + tiisg]*ms0 + ps4[j*T4 + sg*(D + 1*C)/4 + N4*i + tiisg]*ms1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user