mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-24 10:29:21 +01:00
Another very minor speedup on metal
This commit is contained in:
parent
2cb47e0e16
commit
e3ff8c20c8
@ -133,19 +133,24 @@ kernel void kernel_soft_max(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast
|
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
||||||
if (tpitg[0] == 0) {
|
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
||||||
buf[0] = buf[0];
|
//if (tpitg[0] == 0) {
|
||||||
}
|
// buf[0] = buf[0];
|
||||||
|
//}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const float max = buf[0];
|
const float max = buf[0];
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
buf[tpitg[0]] = 0.0f;
|
buf[tpitg[0]] = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
buf[tpitg[0]] += exp(psrc0[i00] - max);
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||||
|
buf[tpitg[0]] += exp_psrc0;
|
||||||
|
// Remember the result of exp here. exp is expensive, so we really do not
|
||||||
|
// whish to compute it twice.
|
||||||
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
@ -157,17 +162,18 @@ kernel void kernel_soft_max(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcast
|
// broadcast - not needed, see above
|
||||||
if (tpitg[0] == 0) {
|
//// broadcast
|
||||||
buf[0] = buf[0];
|
//if (tpitg[0] == 0) {
|
||||||
}
|
// buf[0] = buf[0];
|
||||||
|
//}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
const float sum = buf[0];
|
const float sum = buf[0];
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
pdst[i00] = exp(psrc0[i00] - max) / sum;
|
pdst[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,25 +220,27 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
// broadcast
|
//// broadcast
|
||||||
if (tpitg == 0) {
|
//if (tpitg == 0) {
|
||||||
sum[0] /= ne00;
|
// sum[0] /= ne00;
|
||||||
}
|
//}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float mean = sum[0];
|
const float mean = sum[0];
|
||||||
|
|
||||||
// recenter
|
// recenter and VARIANCE
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
y[i00] = x[i00] - mean;
|
|
||||||
}
|
|
||||||
|
|
||||||
// VARIANCE
|
|
||||||
// parallel sum
|
|
||||||
sum[tpitg] = 0.0f;
|
sum[tpitg] = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
y[i00] = x[i00] - mean;
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//// VARIANCE
|
||||||
|
//// parallel sum
|
||||||
|
//sum[tpitg] = 0.0f;
|
||||||
|
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
|
// sum[tpitg] += y[i00] * y[i00];
|
||||||
|
//}
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
@ -241,11 +249,11 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
// broadcast
|
//// broadcast
|
||||||
if (tpitg == 0) {
|
//if (tpitg == 0) {
|
||||||
sum[0] /= ne00;
|
// sum[0] /= ne00;
|
||||||
}
|
//}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float variance = sum[0];
|
const float variance = sum[0];
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
|
Loading…
Reference in New Issue
Block a user