From 5c977221d218053a01ca61841d3b3e7d550d28d2 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 13 Feb 2024 15:18:27 +0200 Subject: [PATCH] iq1_s: slightly faster dot product --- ggml-metal.metal | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 235a3c7cf..1e52f65ad 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4344,19 +4344,20 @@ void kernel_mul_mv_iq1_s_f32_impl( device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[32]; + float yl[16]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); #if QK_K == 256 - const int ix = tiisg; + const int ix = tiisg/2; + const int il = tiisg%2; - device const float * y4 = y + 32 * ix; + device const float * y4 = y + 32 * ix + 16 * il; - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (int ib32 = ix; ib32 < nb32; ib32 += 16) { - for (int i = 0; i < 32; ++i) { + for (int i = 0; i < 16; ++i) { yl[i] = y4[i]; } @@ -4364,33 +4365,28 @@ void kernel_mul_mv_iq1_s_f32_impl( const int ib = ib32 % (QK_K / 32); device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * sc = xr->scales + 2 * ib; + device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; + device const uint8_t * sc = xr->scales + 2 * ib + il; device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); - constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5))); - constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1))); - float4 sum = {0}; + float2 sum = {0}; for (int j = 0; j < 8; ++j) { sum[0] += yl[j+ 0] * grid1[j]; sum[1] += yl[j+ 8] * grid2[j]; - sum[2] += yl[j+16] * grid3[j]; - sum[3] += yl[j+24] * grid4[j]; } - sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1) + - sum[2] * (2*(sc[1] & 7) + 1) + sum[3] * (2*((sc[1] >> 4) & 7) + 1)); + sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1)); dh += nb*sizeof(block_iq1_s)/2; qs += nb*sizeof(block_iq1_s); sc += nb*sizeof(block_iq1_s); } - y4 += 32 * 32; + y4 += 16 * 32; } #else // TODO