From 7b629c3b65f4d49f8122fc951e00de261a131e54 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Mar 2024 17:46:33 +0200 Subject: [PATCH] iq3_s: minor improvement on Metal 49.4 t/s -> 50.3 t/s --- ggml-metal.metal | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 74a5e0b03..7051c50208 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4773,8 +4773,10 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);