mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 22:08:55 +01:00
ggml : fix q4_1 dot product types
This commit is contained in:
parent
c5d70f5c9e
commit
0f07cacb05
12
ggml.c
12
ggml.c
@ -2344,14 +2344,14 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|||||||
|
|
||||||
#if defined(__ARM_FEATURE_DOTPROD)
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
// dot product into int32x4_t
|
// dot product into int32x4_t
|
||||||
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
|
uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
|
||||||
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
|
uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
|
||||||
|
|
||||||
p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
|
p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
|
||||||
p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
|
p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
|
||||||
|
|
||||||
sum11 += x0->d*y0->d*vaddvq_s32(p_0);
|
sum11 += x0->d*y0->d*vaddvq_u32(p_0);
|
||||||
sum11 += x1->d*y1->d*vaddvq_s32(p_1);
|
sum11 += x1->d*y1->d*vaddvq_u32(p_1);
|
||||||
#else
|
#else
|
||||||
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
||||||
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
||||||
|
Loading…
Reference in New Issue
Block a user