mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
iq3_s: another small ARM_NEON improvement
10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick that works best on AVX2.
This commit is contained in:
parent
1e94989156
commit
9c5b594cde
@ -10106,6 +10106,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
||||||
const int16x8_t hshift = vld1q_s16(k_shift);
|
const int16x8_t hshift = vld1q_s16(k_shift);
|
||||||
const uint16x8_t m256 = vdupq_n_u16(256);
|
const uint16x8_t m256 = vdupq_n_u16(256);
|
||||||
|
const uint16x8_t m1 = vdupq_n_u8(1);
|
||||||
|
|
||||||
uint8x16x2_t vs;
|
uint8x16x2_t vs;
|
||||||
ggml_int8x16x4_t q3s;
|
ggml_int8x16x4_t q3s;
|
||||||
@ -10151,22 +10152,22 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
||||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||||
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
||||||
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
||||||
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
|
||||||
|
|
||||||
q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
|
q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
|
||||||
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
|
q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
|
||||||
|
|
||||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
|
||||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||||
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
||||||
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
|
||||||
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
|
||||||
|
|
||||||
signs += 4;
|
signs += 4;
|
||||||
|
|
||||||
q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
|
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
|
||||||
q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
|
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
|
||||||
|
|
||||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
||||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
||||||
|
Loading…
Reference in New Issue
Block a user