mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 03:12:23 +01:00
ggml : add ARM_NEON ggml_vec_dot_q4_1()
This commit is contained in:
parent
61cbfff5c9
commit
3b44d30d9b
39
ggml.c
39
ggml.c
@ -2008,6 +2008,45 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|||||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||||
|
|
||||||
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
||||||
|
#elif defined(__ARM_NEON)
|
||||||
|
float sum00 = 0.0f;
|
||||||
|
float sum01 = 0.0f;
|
||||||
|
float sum10 = 0.0f;
|
||||||
|
float sum11 = 0.0f;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
||||||
|
const block_q4_1 * restrict y0 = &y[i + 0];
|
||||||
|
|
||||||
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||||
|
|
||||||
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
||||||
|
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
||||||
|
|
||||||
|
// low nibbles: mask each byte with 0xf (high nibbles are extracted by the shifts below)
|
||||||
|
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
||||||
|
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
||||||
|
|
||||||
|
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
||||||
|
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
||||||
|
|
||||||
|
// widening per-lane products (u8 * u8 -> u16) into uint16x8_t; summed across lanes further below
|
||||||
|
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
||||||
|
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
||||||
|
|
||||||
|
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
||||||
|
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
||||||
|
|
||||||
|
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
|
||||||
|
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
|
||||||
|
|
||||||
|
sum00 += x0->m*y0->m;
|
||||||
|
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
|
||||||
|
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
|
||||||
|
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
|
||||||
|
}
|
||||||
|
|
||||||
|
sumf = QK*sum00 + sum01 + sum10 + sum11;
|
||||||
#else
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
|
Loading…
Reference in New Issue
Block a user