mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 15:18:26 +01:00
metal : reduce branches
This commit is contained in:
parent
528da7515e
commit
52ae085750
@ -2056,40 +2056,26 @@ kernel void kernel_flash_attn_ext_f16(
|
||||
continue;
|
||||
}
|
||||
|
||||
half4 s4 = 0.0f;
|
||||
device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
device const half4 * pk4 = (device const half4 *) ((device char *) k + ( ic*nb11 + ik2*nb12 + ik3*nb13));
|
||||
half4 s4 = 0.0h;
|
||||
|
||||
for (int64_t d = 0; d < D4; ++d) {
|
||||
s4 += pk4[d] * pq4[d];
|
||||
}
|
||||
|
||||
half s = s4.x + s4.y + s4.z + s4.w;
|
||||
|
||||
s = s*scale + mv;
|
||||
half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
|
||||
|
||||
const half Mold = M;
|
||||
|
||||
half ms = 1.0f;
|
||||
half vs = 1.0f;
|
||||
M = max(M, s);
|
||||
|
||||
if (s > M) {
|
||||
M = s;
|
||||
ms = exp(Mold - M);
|
||||
const half ms = exp(Mold - M);
|
||||
const half vs = exp(s - M);
|
||||
|
||||
// V = V*exp(Mold - M)
|
||||
for (int64_t d = 0; d < D4; ++d) {
|
||||
V16[d] *= ms;
|
||||
}
|
||||
} else {
|
||||
vs = exp(s - M);
|
||||
}
|
||||
|
||||
device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
// V += v*exp(s - M)
|
||||
for (int64_t d = 0; d < D4; ++d) {
|
||||
V16[d] += pv4[d] * vs;
|
||||
V16[d] = V16[d]*ms + pv4[d]*vs;
|
||||
}
|
||||
|
||||
S = S*ms + vs;
|
||||
|
Loading…
Reference in New Issue
Block a user