From f238461236f4e0e18cac1a554af23c7deadc9b01 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jan 2024 14:02:30 +0200 Subject: [PATCH] ggml : fix 32-bit ARM compat for IQ2_XS (whisper/1758) * ggml : fix 32-bit ARM compat * ggml : fix fix * ggml : fix fix fix --- ggml-quants.c | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/ggml-quants.c b/ggml-quants.c index a24b4b244..601d155d7 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -272,10 +272,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 // vaddvq_s16 // vpaddq_s16 +// vpaddq_s32 // vaddvq_s32 // vaddvq_f32 // vmaxvq_f32 // vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 inline static int32_t vaddvq_s16(int16x8_t v) { return @@ -291,6 +294,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { return vcombine_s16(a0, b0); } +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + inline static int32_t vaddvq_s32(int32x4_t v) { return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); } @@ -316,6 +325,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { return res; } +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + // vld1q_s16_x2 // vld1q_u8_x2 // vld1q_u8_x4 @@ -7554,9 +7585,9 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - int8x16x4_t q2u; - int8x16x4_t q2s; - int8x16x4_t q8b; + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; int32x4x4_t scales32; @@ -7578,7 +7609,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); int32x4_t sumi = vdupq_n_s32(0); for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { - q8b = vld1q_s8_x4(q8); q8 += 64; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));