mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 06:40:45 +01:00
iq1_s: ARM_NEON dot product. Works, but not very fast
This commit is contained in:
parent
2ffb05acc8
commit
773014926f
@ -9303,7 +9303,64 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined __AVX2__
|
||||
#if defined __ARM_NEON
|
||||
|
||||
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
||||
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
||||
const uint8x16_t m1 = vdupq_n_u8(0x01);
|
||||
const int32x4_t vzero = vdupq_n_s32(0);
|
||||
|
||||
uint16_t gindex[8];
|
||||
uint16x8x2_t vindex;
|
||||
int8x16x4_t q1b;
|
||||
int8x16x4_t q8b;
|
||||
uint16x8x4_t scales;
|
||||
int32x4x2_t sumi;
|
||||
int32x4x2_t dotq;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint8_t * sc = x[i].scales;
|
||||
|
||||
sumi.val[0] = sumi.val[1] = vzero;
|
||||
|
||||
for (int i128 = 0; i128 < QK_K/128; ++i128) {
|
||||
const uint8x16_t ql = vld1q_u8(qs); qs += 16;
|
||||
const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
|
||||
const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
|
||||
const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
|
||||
const uint8x16_t hbit = vandq_u8(qh, m8);
|
||||
vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
|
||||
vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
|
||||
const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
|
||||
scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
|
||||
scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
|
||||
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
vst1q_u16(gindex+0, vindex.val[l]);
|
||||
q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
|
||||
q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
|
||||
q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
|
||||
q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
|
||||
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||
|
||||
dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
|
||||
dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
|
||||
|
||||
sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
|
||||
sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
|
||||
}
|
||||
}
|
||||
|
||||
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
|
||||
#elif defined __AVX2__
|
||||
|
||||
const __m128i m8 = _mm_set1_epi8(0x08);
|
||||
const __m128i m7 = _mm_set1_epi8(0x07);
|
||||
|
Loading…
Reference in New Issue
Block a user