From a6a263b919a89cd5523286c1c09f630529967788 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 4 Mar 2024 20:10:36 +0200 Subject: [PATCH] iq3_s_mult_shuffle: works on ARM_NEON and Metal --- ggml-metal.metal | 28 ++++++++++++++++------------ ggml-quants.c | 9 +++++---- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 69a928c24..550bc682e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2550,7 +2550,9 @@ typedef struct { #define IQ3S_MULTIPLIER 190842953 #else //#define IQ3S_MULTIPLIER 898886 -#define IQ3S_MULTIPLIER 842866 +//#define IQ3S_MULTIPLIER 842866 +#define IQ3S_MULTIPLIER 72968561ULL +constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15}; #endif typedef struct { @@ -4698,9 +4700,9 @@ void kernel_mul_mv_iq3_s_f32_impl( { int nval = 8; int pos = (32*sgitg + tiisg)*nval; -#ifdef IQ3S_SLOW_MULT uint32_t aux32; thread int8_t * q = (thread int8_t *)&aux32; +#ifdef IQ3S_SLOW_MULT for (int i = 0; i < nval; ++i) { aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1; @@ -4708,7 +4710,9 @@ void kernel_mul_mv_iq3_s_f32_impl( } #else for (int i = 0; i < nval; ++i) { - values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101; + aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f; + for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]]; + values[pos + i] = aux32; } #endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5667,7 +5671,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); uint32_t aux32[2]; thread const int8_t * grid = (thread const int8_t *)aux32; -#ifdef IQ3S_SLOW)MULT +#ifdef IQ3S_SLOW_MULT aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { @@ -5681,17 +5685,17 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } #else - aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); } - aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101; - aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101; + aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f; + aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f; for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); } #endif } diff --git a/ggml-quants.c b/ggml-quants.c index 0535f5c25..e7fc0f85c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10005,6 +10005,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t shuff = vld1q_u8(iq3s_values); const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER); const int16x8_t idx_shift = vld1q_s16(k_shift); @@ -10042,10 +10043,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1); q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1); #else - q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1); - q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1); - q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1); - q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1); + q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2))); + q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2))); + q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2))); + q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2))); #endif vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));