mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 12:57:03 +01:00
metal : improve precision
This commit is contained in:
parent
ecc466a460
commit
3a428a1097
@ -2120,7 +2120,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
|
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
|
||||||
|
|
||||||
// prepare diagonal scale matrix
|
// prepare diagonal scale matrix
|
||||||
simdgroup_half8x8 mscale(scale);
|
simdgroup_float8x8 mscale(scale);
|
||||||
|
|
||||||
// loop over the KV cache
|
// loop over the KV cache
|
||||||
// each simdgroup handles blocks of Q rows and C columns
|
// each simdgroup handles blocks of Q rows and C columns
|
||||||
@ -2163,7 +2163,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
}
|
}
|
||||||
|
|
||||||
const half ms = exp(m - M[j]);
|
const half ms = m == -INFINITY ? 0.0h : exp(m - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms;
|
S[j] = S[j]*ms;
|
||||||
|
|
||||||
@ -2175,7 +2175,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
for (int64_t p = tiisg; p < C; p += NW) {
|
for (int64_t p = tiisg; p < C; p += NW) {
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
const half vs = exp(s - M[j]);
|
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j] + simd_sum(vs);
|
S[j] = S[j] + simd_sum(vs);
|
||||||
|
|
||||||
@ -2255,8 +2255,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
|
|
||||||
M = max(M0, M1);
|
M = max(M0, M1);
|
||||||
|
|
||||||
const half ms0 = exp(M0 - M);
|
const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
|
||||||
const half ms1 = exp(M1 - M);
|
const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
|
||||||
|
|
||||||
S = S0*ms0 + S1*ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user