metal : fix kernel_norm

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-09-07 14:11:21 +03:00
parent fec2fb19e4
commit 5e1c4089d8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 24 additions and 23 deletions

View File

@ -995,8 +995,12 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
int64_t ny = (ne11 + 3)/4; [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
// TODO: this breaks for Q4_0 - understand why and fix it
//int64_t ny = (ne11 + 3)/4;
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
} }
} break; } break;

View File

@ -220,27 +220,26 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
//// broadcast // broadcast
//if (tpitg == 0) { if (tpitg == 0) {
// sum[0] /= ne00; sum[0] /= ne00;
//} }
//threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0]; const float mean = sum[0];
// recenter and VARIANCE // recenter
device float * y = dst + tgpig*ne00; device float * y = dst + tgpig*ne00;
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] - mean; y[i00] = x[i00] - mean;
}
// VARIANCE
// parallel sum
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
sum[tpitg] += y[i00] * y[i00]; sum[tpitg] += y[i00] * y[i00];
} }
//// VARIANCE
//// parallel sum
//sum[tpitg] = 0.0f;
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
// sum[tpitg] += y[i00] * y[i00];
//}
// reduce // reduce
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) { for (uint i = ntg/2; i > 0; i /= 2) {
@ -249,11 +248,11 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
//// broadcast // broadcast
//if (tpitg == 0) { if (tpitg == 0) {
// sum[0] /= ne00; sum[0] /= ne00;
//} }
//threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
const float variance = sum[0]; const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + eps); const float scale = 1.0f/sqrt(variance + eps);
@ -262,7 +261,6 @@ kernel void kernel_norm(
} }
} }
kernel void kernel_rms_norm( kernel void kernel_rms_norm(
device const void * src0, device const void * src0,
device float * dst, device float * dst,
@ -630,7 +628,6 @@ kernel void kernel_mul_mat_f16_f32(
} }
} }
} }
} }
kernel void kernel_alibi_f32( kernel void kernel_alibi_f32(