mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 05:42:22 +01:00
finish i2_s/i8_s vec_dot x86 simd
This commit is contained in:
parent
95dced07e4
commit
569a03ed97
128
ggml-common.h
128
ggml-common.h
@ -1023,70 +1023,70 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
|
||||
GGML_TABLE_END()
|
||||
|
||||
GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256)
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||
0x00010101, 0x01010101, 0x00010101, 0xff010101,
|
||||
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||
0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||
0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
|
||||
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||
0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||
0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
|
||||
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||
0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||
0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
|
||||
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||
0x00010101, 0x01010101, 0x00010101, 0xff010101,
|
||||
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||
0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||
0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
|
||||
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||
0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||
0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
|
||||
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||
0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||
0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
|
||||
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
|
||||
GGML_TABLE_END()
|
||||
|
||||
#define NGRID_IQ1S 2048
|
||||
|
@ -3799,60 +3799,61 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res
|
||||
UNUSED(by);
|
||||
UNUSED(nrc);
|
||||
|
||||
// TODO
|
||||
// #if defined(__AVX2__)
|
||||
// __m256i accu = _mm256_setzero_si256();
|
||||
#if defined(__AVX2__)
|
||||
__m256i accu = _mm256_setzero_si256();
|
||||
|
||||
// for (int i=0; i<n/32; i++) {
|
||||
// const int8_t* w0 = (const int8_t *)(i2s_i8s + x[i*8 + 0]);
|
||||
// const int8_t* w1 = (const int8_t *)(i2s_i8s + x[i*8 + 1]);
|
||||
// const int8_t* w2 = (const int8_t *)(i2s_i8s + x[i*8 + 2]);
|
||||
// const int8_t* w3 = (const int8_t *)(i2s_i8s + x[i*8 + 3]);
|
||||
// const int8_t* w4 = (const int8_t *)(i2s_i8s + x[i*8 + 4]);
|
||||
// const int8_t* w5 = (const int8_t *)(i2s_i8s + x[i*8 + 5]);
|
||||
// const int8_t* w6 = (const int8_t *)(i2s_i8s + x[i*8 + 6]);
|
||||
// const int8_t* w7 = (const int8_t *)(i2s_i8s + x[i*8 + 7]);
|
||||
// max group_size is 128 (2^8)
|
||||
// limited by 8640 to 2 (8640 % (2 * 32) == 0)
|
||||
int group_num = 2;
|
||||
|
||||
// __m256i xq8 = _mm256_set_epi8(
|
||||
// w0[0], w0[1], w0[2], w0[3],
|
||||
// w1[0], w1[1], w1[2], w1[3],
|
||||
// w2[0], w2[1], w2[2], w2[3],
|
||||
// w3[0], w3[1], w3[2], w3[3],
|
||||
// w4[0], w4[1], w4[2], w4[3],
|
||||
// w5[0], w5[1], w5[2], w5[3],
|
||||
// w6[0], w6[1], w6[2], w6[3],
|
||||
// w7[0], w7[1], w7[2], w7[3]
|
||||
// );
|
||||
for (int i=0; i < n / (group_num * 32); i++){
|
||||
__m256i laccu = _mm256_setzero_si256();
|
||||
__m256i haccu = _mm256_setzero_si256();
|
||||
|
||||
// __m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i*32));
|
||||
for (int j=0; j < group_num; j++) {
|
||||
__m256i xq8 = _mm256_set_epi32(
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 7]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 6]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 5]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 4]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 3]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 2]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 1]],
|
||||
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 0]]
|
||||
);
|
||||
|
||||
// __m128i hxq8 = _mm256_castsi256_si128(xq8);
|
||||
// __m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
|
||||
// __m128i hyq8 = _mm256_castsi256_si128(yq8);
|
||||
// __m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
|
||||
__m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i * group_num * 32 + j * 32));
|
||||
|
||||
// __m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
|
||||
// __m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
|
||||
// __m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
|
||||
// __m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
|
||||
__m128i hxq8 = _mm256_castsi256_si128(xq8);
|
||||
__m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
|
||||
__m128i hyq8 = _mm256_castsi256_si128(yq8);
|
||||
__m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
|
||||
|
||||
// __m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
|
||||
// __m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
|
||||
__m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
|
||||
__m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
|
||||
__m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
|
||||
__m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
|
||||
|
||||
// __m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(hzq16));
|
||||
// __m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(hzq16, 1));
|
||||
// __m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(lzq16));
|
||||
// __m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(lzq16, 1));
|
||||
__m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
|
||||
__m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
|
||||
|
||||
// accu = _mm256_add_epi32(accu, hhzq32);
|
||||
// accu = _mm256_add_epi32(accu, hlzq32);
|
||||
// accu = _mm256_add_epi32(accu, llzq32);
|
||||
// accu = _mm256_add_epi32(accu, lhzq32);
|
||||
// }
|
||||
haccu = _mm256_add_epi16(haccu, hzq16);
|
||||
laccu = _mm256_add_epi16(laccu, lzq16);
|
||||
}
|
||||
|
||||
// int sumi = hsum_i32_8(accu);
|
||||
// *s = (float)sumi;
|
||||
// #else
|
||||
__m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(haccu));
|
||||
__m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(haccu, 1));
|
||||
__m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(laccu));
|
||||
__m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(laccu, 1));
|
||||
|
||||
accu = _mm256_add_epi32(accu, hhzq32);
|
||||
accu = _mm256_add_epi32(accu, hlzq32);
|
||||
accu = _mm256_add_epi32(accu, llzq32);
|
||||
accu = _mm256_add_epi32(accu, lhzq32);
|
||||
}
|
||||
int sumi = hsum_i32_8(accu);
|
||||
*s = (float)sumi;
|
||||
#else
|
||||
|
||||
int sumi = 0;
|
||||
|
||||
@ -3864,7 +3865,7 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res
|
||||
sumi += (int)y[i*4+3] * weight[3];
|
||||
}
|
||||
*s = (float)sumi;
|
||||
// #endif
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
|
@ -11702,7 +11702,6 @@ struct llm_build_context {
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
|
||||
cb(cur, "ffn_down", il);
|
||||
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
@ -11723,7 +11722,6 @@ struct llm_build_context {
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user