diff --git a/ggml-quants.c b/ggml-quants.c index a3a8ab9c8..4a3cc2722 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9583,7 +9583,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void const uint8_t * qs = x[i].qs; const uint16_t * qh = x[i].qh; - int sumi1 = 0, sumi2 = 0; + int sumi1 = 0, sumi2 = 0, sumi3 = 0; for (int ib = 0; ib < QK_K/32; ib += 2) { @@ -9602,12 +9602,16 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - sumi1 += vaddvq_s32(p1) * (2*(qh[ib+0] >> 12) + 1); - sumi2 += vaddvq_s32(p2) * (2*(qh[ib+1] >> 12) + 1); + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); } - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2); + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); } *s = sumf;