mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
iq3_s: somewhat faster ARM_NEON dot product
Still dog slow - 10.7 t/s up from 9.9 t/s.
This commit is contained in:
parent
39e3a429c8
commit
1e94989156
@ -10089,18 +10089,33 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
|
||||
typedef union {
|
||||
uint16x8_t vec_index;
|
||||
uint16_t index[8];
|
||||
} vec_index_t;
|
||||
|
||||
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
||||
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
||||
};
|
||||
|
||||
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
||||
|
||||
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
|
||||
|
||||
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
|
||||
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
||||
const int16x8_t hshift = vld1q_s16(k_shift);
|
||||
const uint16x8_t m256 = vdupq_n_u16(256);
|
||||
|
||||
uint8x16x2_t vs;
|
||||
ggml_int8x16x4_t q3s;
|
||||
ggml_int8x16x4_t q8b;
|
||||
vec_index_t idx;
|
||||
|
||||
#if QK_K == 256
|
||||
uint32_t scales32[2];
|
||||
const uint8_t * scales8 = (const uint8_t *)scales32;
|
||||
#endif
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
@ -10109,18 +10124,29 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
|
||||
#if QK_K == 256
|
||||
memcpy(scales32, x[i].scales, 4);
|
||||
scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
|
||||
scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
|
||||
#endif
|
||||
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||
const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
|
||||
iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
|
||||
const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
|
||||
iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
|
||||
const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
|
||||
iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
|
||||
const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
|
||||
iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
|
||||
qs += 16;
|
||||
|
||||
const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
|
||||
idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
|
||||
const uint32x4_t aux32x4_0 = {iq3xs_grid[idx.index[0]], iq3xs_grid[idx.index[1]],
|
||||
iq3xs_grid[idx.index[2]], iq3xs_grid[idx.index[3]]};
|
||||
const uint32x4_t aux32x4_1 = {iq3xs_grid[idx.index[4]], iq3xs_grid[idx.index[5]],
|
||||
iq3xs_grid[idx.index[6]], iq3xs_grid[idx.index[7]]};
|
||||
idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
|
||||
const uint32x4_t aux32x4_2 = {iq3xs_grid[idx.index[0]], iq3xs_grid[idx.index[1]],
|
||||
iq3xs_grid[idx.index[2]], iq3xs_grid[idx.index[3]]};
|
||||
const uint32x4_t aux32x4_3 = {iq3xs_grid[idx.index[4]], iq3xs_grid[idx.index[5]],
|
||||
iq3xs_grid[idx.index[6]], iq3xs_grid[idx.index[7]]};
|
||||
|
||||
|
||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||
@ -10144,8 +10170,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
||||
|
||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
|
||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
||||
#if QK_K == 256
|
||||
sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
|
||||
sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
|
||||
#else
|
||||
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
|
||||
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
|
||||
#endif
|
||||
}
|
||||
sumf += d*(sumi1 + sumi2);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user