diff --git a/ggml-metal.m b/ggml-metal.m index 521ca180f..5135e1cbb 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -995,8 +995,12 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + 3)/4; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + + // TODO: this breaks for Q4_0 - understand why and fix it + //int64_t ny = (ne11 + 3)/4; + //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 119fcbeb6..a107e7d97 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -111,7 +111,7 @@ kernel void kernel_soft_max( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; + const int64_t i03 = tgpig[2]; const int64_t i02 = tgpig[1]; const int64_t i01 = tgpig[0]; @@ -220,27 +220,26 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); const float mean = sum[0]; - // recenter and VARIANCE + // recenter device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { y[i00] = x[i00] - mean; + } + + // VARIANCE + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { sum[tpitg] += y[i00] * y[i00]; } - //// VARIANCE - //// parallel sum - //sum[tpitg] = 0.0f; - //for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - // sum[tpitg] += y[i00] * y[i00]; - //} // reduce threadgroup_barrier(mem_flags::mem_threadgroup); for (uint i = ntg/2; i > 0; i /= 2) { @@ -249,11 +248,11 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - //// broadcast - //if (tpitg == 0) { - // sum[0] /= ne00; - //} - //threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); const float variance = sum[0]; const float scale = 1.0f/sqrt(variance + eps); @@ -262,7 +261,6 @@ kernel void kernel_norm( } } - kernel void kernel_rms_norm( device const void * src0, device float * dst, @@ -630,7 +628,6 @@ kernel void kernel_mul_mat_f16_f32( } } } - } kernel void kernel_alibi_f32(