mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
iq2_xs: working, but dog slow, ARM_NEON dot product
This commit is contained in:
parent
55e2cae83f
commit
ff49d876c6
@ -7550,13 +7550,10 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined(z__ARM_NEON)
|
||||
#if defined(__ARM_NEON)
|
||||
|
||||
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
||||
|
||||
uint32_t aux32[4];
|
||||
const uint8_t * aux8 = (const uint8_t *)aux32;
|
||||
|
||||
int8x16x4_t q2u;
|
||||
int8x16x4_t q2s;
|
||||
int8x16x4_t q8b;
|
||||
@ -7565,27 +7562,33 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint16_t * restrict q2 = x[i].qs;
|
||||
const uint8_t * restrict sc = x[i].scales;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
float sumf1 = 0, sumf2 = 0;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
q8b = vld1q_s8_x4(q8); q8 += 64;
|
||||
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
|
||||
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
|
||||
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
|
||||
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
|
||||
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
|
||||
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
|
||||
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
|
||||
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
|
||||
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
|
||||
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
|
||||
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
|
||||
q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
|
||||
q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
|
||||
q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
|
||||
q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
|
||||
q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
|
||||
q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
|
||||
q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
|
||||
q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
|
||||
q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
|
||||
q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
|
||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
|
||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
|
||||
sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
|
||||
sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
|
||||
const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
|
||||
const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
|
||||
const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
|
||||
const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
|
||||
sumf1 += vaddvq_s32(p1) * (0.5f + (sc[0] & 0xf));
|
||||
sumf2 += vaddvq_s32(p2) * (0.5f + (sc[0] >> 4));
|
||||
sumf1 += vaddvq_s32(p3) * (0.5f + (sc[1] & 0xf));
|
||||
sumf2 += vaddvq_s32(p4) * (0.5f + (sc[1] >> 4));
|
||||
q2 += 8;
|
||||
sc += 2;
|
||||
}
|
||||
sumf += d*(sumf1 + sumf2);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user