mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-19 08:20:10 +01:00
metal : fix kernel_norm
ggml-ci
This commit is contained in:
parent
fec2fb19e4
commit
5e1c4089d8
@ -995,8 +995,12 @@ void ggml_metal_graph_compute(
|
|||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + 3)/4;
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
|
||||||
|
// TODO: this breaks for Q4_0 - understand why and fix it
|
||||||
|
//int64_t ny = (ne11 + 3)/4;
|
||||||
|
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
@ -220,27 +220,26 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
// broadcast
|
||||||
//if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
// sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
//}
|
}
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float mean = sum[0];
|
const float mean = sum[0];
|
||||||
|
|
||||||
// recenter and VARIANCE
|
// recenter
|
||||||
device float * y = dst + tgpig*ne00;
|
device float * y = dst + tgpig*ne00;
|
||||||
sum[tpitg] = 0.0f;
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
y[i00] = x[i00] - mean;
|
y[i00] = x[i00] - mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// VARIANCE
|
||||||
|
// parallel sum
|
||||||
|
sum[tpitg] = 0.0f;
|
||||||
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
sum[tpitg] += y[i00] * y[i00];
|
sum[tpitg] += y[i00] * y[i00];
|
||||||
}
|
}
|
||||||
|
|
||||||
//// VARIANCE
|
|
||||||
//// parallel sum
|
|
||||||
//sum[tpitg] = 0.0f;
|
|
||||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
||||||
// sum[tpitg] += y[i00] * y[i00];
|
|
||||||
//}
|
|
||||||
// reduce
|
// reduce
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||||
@ -249,11 +248,11 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
//// broadcast
|
// broadcast
|
||||||
//if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
// sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
//}
|
}
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
const float variance = sum[0];
|
const float variance = sum[0];
|
||||||
|
|
||||||
const float scale = 1.0f/sqrt(variance + eps);
|
const float scale = 1.0f/sqrt(variance + eps);
|
||||||
@ -262,7 +261,6 @@ kernel void kernel_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
kernel void kernel_rms_norm(
|
kernel void kernel_rms_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -630,7 +628,6 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
|
Loading…
Reference in New Issue
Block a user