mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-02-06 16:40:34 +01:00
iq3_s_mult: ARM_NEON works - 13 t/s
This commit is contained in:
parent
0fe9cd488f
commit
bf90920fb2
@ -10048,9 +10048,18 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
|
|
||||||
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
||||||
|
|
||||||
|
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
|
||||||
|
|
||||||
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
|
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
|
||||||
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
||||||
|
|
||||||
|
const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER);
|
||||||
|
const int16x8_t idx_shift = vld1q_s16(k_shift);
|
||||||
|
const uint16x8_t idx_mask1 = vdupq_n_u16(256);
|
||||||
|
const uint32x4_t idx_mask2 = vdupq_n_u32(0x0f0f0f0f);
|
||||||
|
const int8x16_t m1 = vdupq_n_s8(1);
|
||||||
|
const int8x16_t m0 = vdupq_n_s8(0);
|
||||||
|
|
||||||
uint8x16x2_t vs;
|
uint8x16x2_t vs;
|
||||||
ggml_int8x16x4_t q3s;
|
ggml_int8x16x4_t q3s;
|
||||||
ggml_int8x16x4_t q8b;
|
ggml_int8x16x4_t q8b;
|
||||||
@ -10065,35 +10074,39 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
int sumi1 = 0, sumi2 = 0;
|
int sumi1 = 0, sumi2 = 0;
|
||||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||||
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||||
const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
|
const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
|
||||||
iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
|
const uint16x8_t idx_1 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), idx_shift), idx_mask1),
|
||||||
const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
|
vmovl_u8(vget_low_u8(idx_l)));
|
||||||
iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
|
const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
|
||||||
const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
|
vmovl_u8(vget_high_u8(idx_l)));
|
||||||
iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
|
q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2));
|
||||||
const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
|
q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2));
|
||||||
iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
|
q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2));
|
||||||
qs += 16;
|
q3s.val[3] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2));
|
||||||
|
q3s.val[0] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[0], m1), m0), 1), 1), m1);
|
||||||
|
q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1);
|
||||||
|
q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
|
||||||
|
q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
|
||||||
|
|
||||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
||||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||||
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
||||||
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
|
||||||
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
|
||||||
|
|
||||||
q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
|
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[0]);
|
||||||
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
|
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[1]);
|
||||||
|
|
||||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
||||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||||
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
||||||
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
|
||||||
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
|
||||||
|
|
||||||
signs += 4;
|
signs += 4;
|
||||||
|
|
||||||
q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
|
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
|
||||||
q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
|
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
|
||||||
|
|
||||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
||||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
||||||
@ -10102,7 +10115,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
}
|
}
|
||||||
sumf += d*(sumi1 + sumi2);
|
sumf += d*(sumi1 + sumi2);
|
||||||
}
|
}
|
||||||
*s = 0.25f * sumf;
|
*s = sumf;
|
||||||
|
|
||||||
#elif defined(__AVX2__)
|
#elif defined(__AVX2__)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user