Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-26 06:10:29 +01:00)
ggml : optimize non-SIMD Q4_0 vector dot product (#703)
This commit is contained in:
parent 6c248707f5
commit 6232f2d7fd
12
ggml.c
12
ggml.c
@@ -2160,18 +2160,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|||||||
const uint8_t * restrict p0 = x[i].qs;
|
const uint8_t * restrict p0 = x[i].qs;
|
||||||
const uint8_t * restrict p1 = y[i].qs;
|
const uint8_t * restrict p1 = y[i].qs;
|
||||||
|
|
||||||
|
int sumi = 0;
|
||||||
for (int j = 0; j < QK/2; j++) {
|
for (int j = 0; j < QK/2; j++) {
|
||||||
const uint8_t v0 = p0[j];
|
const uint8_t v0 = p0[j];
|
||||||
const uint8_t v1 = p1[j];
|
const uint8_t v1 = p1[j];
|
||||||
|
|
||||||
const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
|
const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
|
||||||
const float f1 = d0*((int8_t) (v0 >> 4) - 8);
|
const int8_t i1 = (int8_t) (v0 >> 4) - 8;
|
||||||
|
|
||||||
const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
|
const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
|
||||||
const float f3 = d1*((int8_t) (v1 >> 4) - 8);
|
const int8_t i3 = (int8_t) (v1 >> 4) - 8;
|
||||||
|
|
||||||
sumf += f0*f2 + f1*f3;
|
sumi += i0*i2 + i1*i3;
|
||||||
}
|
}
|
||||||
|
sumf += d0 * d1 * sumi;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user