mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
metal : switch to execution barriers + fix one of the barriers
This commit is contained in:
parent
109e7aa8ac
commit
e1241d9b46
@ -385,8 +385,11 @@ kernel void kernel_soft_max(
|
|||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This barrier fixes a failing test
|
||||||
|
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if (ntg > N_SIMDWIDTH) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
@ -470,9 +473,13 @@ kernel void kernel_soft_max_4(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
// This barrier fixes a failing test
|
||||||
|
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
|
|
||||||
if (ntg > N_SIMDWIDTH) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
buf[tiisg] = 0.0f;
|
buf[tiisg] = 0.0f;
|
||||||
|
Loading…
Reference in New Issue
Block a user