mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-28 15:18:26 +01:00)
metal : optimize softmax

parent 56e45a239e
commit cda5a60a41
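Summary of the change below: in the flash-attention kernel, the QxQ rescale diagonal used to be written from inside the per-row online-softmax loop, one predicated store per row. This commit stashes the per-row rescale factors in a small ms[Q] array and writes the whole diagonal in a single if (tiisg < Q) step after the loop, in both the C == 32 fast path and the generic path. The host side gains a matching GGML_ASSERT(nqptg <= 32).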
ggml-metal.m
@@ -2285,8 +2285,9 @@ static bool ggml_metal_graph_compute(
                 const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

+                GGML_ASSERT(nqptg <= 32);
                 GGML_ASSERT(nqptg % 8  == 0);
                 GGML_ASSERT(ncpsg % 32 == 0);

                 // simdgroups per threadgroup (a.k.a. warps)
                 // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
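Note: the added assert pins the number of queries per threadgroup to at most the simdgroup width (32 lanes), which the reworked kernel below relies on when a single if (tiisg < Q) store is expected to cover every row of the rescale diagonal.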
ggml-metal.metal
@@ -2188,6 +2188,8 @@ kernel void kernel_flash_attn_ext_f16(

             // online softmax
             if (C == 32) {
+                half ms[Q];
+
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = tiisg;

@@ -2197,20 +2199,22 @@ kernel void kernel_flash_attn_ext_f16(
                     smax = simd_max(max(smax, s));
                     M[j] = simd_max(max(M[j], s));

-                    const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
                     const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);

-                    S[j] = S[j]*ms + simd_sum(vs);
+                    S[j] = S[j]*ms[j] + simd_sum(vs);

-                    // create a QxQ diagonal matrix for rescaling the output
-                    if (p == j) {
-                        ss[j*T + C + j] = ms;
-                    }
-
                     // the P matrix from the paper (Q rows, C columns)
                     ss[j*T + p] = vs;
                 }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
             } else {
+                half ms[Q];
+
                 for (int64_t j = 0; j < Q; ++j) {
                     const half m = M[j];

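For reference, the loop above implements the standard online-softmax recurrence: per query row, keep a running maximum M and running sum S; each new score s can bump the maximum, and anything accumulated under the old maximum must be rescaled by ms = exp(m_old - m_new). Below is a scalar C sketch of that recurrence (not the Metal code; softmax_state and online_softmax_step are hypothetical names for illustration, while M, S, ms, vs mirror the kernel's):

    #include <math.h>

    typedef struct {
        float M; // running maximum of the scores seen so far
        float S; // running sum of exp(score - M)
    } softmax_state;

    // Fold one new score s into the row state. Returns ms, the factor by
    // which output accumulated under the old maximum must be rescaled --
    // the value the kernel now stashes in ms[j] for the deferred diagonal write.
    static float online_softmax_step(softmax_state *st, float s) {
        const float m = st->M;                                      // old maximum
        st->M = fmaxf(m, s);                                        // new maximum
        const float ms = (m == -INFINITY) ? 0.0f : expf(m - st->M);
        const float vs = (s == -INFINITY) ? 0.0f : expf(s - st->M);
        st->S = st->S*ms + vs; // rescale the old sum, add the new term
        return ms;
    }

Keeping ms[j] around instead of consuming it inside the loop is what lets the diagonal write move below the loop.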
@@ -2224,12 +2228,7 @@ kernel void kernel_flash_attn_ext_f16(
                     smax = simd_max(smax);
                     M[j] = simd_max(M[j]);

-                    const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
-
-                    // create a QxQ diagonal matrix for rescaling the output
-                    if (tiisg == j) {
-                        ss[j*T + C + j] = ms;
-                    }
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);

                     // local sum
                     half ls = 0.0h;
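The generic (C != 32) path gets the same treatment: the rescale factor moves into ms[j] here, and the diagonal write removed above reappears after the loop in the next hunk.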
@@ -2245,7 +2244,12 @@ kernel void kernel_flash_attn_ext_f16(
                         ss[j*T + p] = vs;
                     }

-                    S[j] = S[j]*ms + simd_sum(ls);
+                    S[j] = S[j]*ms[j] + simd_sum(ls);
                 }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
             }

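The store-pattern effect of the hoist, modeled in plain C (a sketch, not the kernel: it assumes 32 lanes per simdgroup and invents T = C + Q for the shared-buffer row stride plus some rescale values; in the kernel, T and ss come from the template arguments and threadgroup memory):

    #include <stdio.h>

    enum { Q = 8, C = 32, T = C + Q }; // T (row stride) is an assumption for this demo

    int main(void) {
        float ss[Q*T] = {0};
        float ms[Q]   = {0.5f, 1.0f, 1.0f, 0.25f, 1.0f, 1.0f, 0.125f, 1.0f};

        // emulate the 32 lanes of one simdgroup (tiisg = thread index in simdgroup)
        for (int tiisg = 0; tiisg < 32; ++tiisg) {
            // old scheme: for each j in 0..Q-1, only the lane with tiisg == j
            // stored ss[j*T + C + j] -- Q predicated passes per simdgroup.
            // new scheme: every lane below Q stores its own diagonal entry once.
            if (tiisg < Q) {
                ss[tiisg*T + C + tiisg] = ms[tiisg];
            }
        }

        for (int j = 0; j < Q; ++j) {
            printf("diag[%d] = %g\n", j, ss[j*T + C + j]);
        }
        return 0;
    }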