finish i2_s/i8_s vec_dot x86 simd

This commit is contained in:
Eddie-Wang 2024-06-15 14:01:26 +00:00
parent 95dced07e4
commit 569a03ed97
3 changed files with 111 additions and 112 deletions

View File

@ -3799,60 +3799,61 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res
UNUSED(by);
UNUSED(nrc);
// TODO
// #if defined(__AVX2__)
// __m256i accu = _mm256_setzero_si256();
#if defined(__AVX2__)
__m256i accu = _mm256_setzero_si256();
// for (int i=0; i<n/32; i++) {
// const int8_t* w0 = (const int8_t *)(i2s_i8s + x[i*8 + 0]);
// const int8_t* w1 = (const int8_t *)(i2s_i8s + x[i*8 + 1]);
// const int8_t* w2 = (const int8_t *)(i2s_i8s + x[i*8 + 2]);
// const int8_t* w3 = (const int8_t *)(i2s_i8s + x[i*8 + 3]);
// const int8_t* w4 = (const int8_t *)(i2s_i8s + x[i*8 + 4]);
// const int8_t* w5 = (const int8_t *)(i2s_i8s + x[i*8 + 5]);
// const int8_t* w6 = (const int8_t *)(i2s_i8s + x[i*8 + 6]);
// const int8_t* w7 = (const int8_t *)(i2s_i8s + x[i*8 + 7]);
// max group_size is 128 (2^8)
// limited by 8640 to 2 (8640 % (2 * 32) == 0)
int group_num = 2;
// __m256i xq8 = _mm256_set_epi8(
// w0[0], w0[1], w0[2], w0[3],
// w1[0], w1[1], w1[2], w1[3],
// w2[0], w2[1], w2[2], w2[3],
// w3[0], w3[1], w3[2], w3[3],
// w4[0], w4[1], w4[2], w4[3],
// w5[0], w5[1], w5[2], w5[3],
// w6[0], w6[1], w6[2], w6[3],
// w7[0], w7[1], w7[2], w7[3]
// );
for (int i=0; i < n / (group_num * 32); i++){
__m256i laccu = _mm256_setzero_si256();
__m256i haccu = _mm256_setzero_si256();
// __m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i*32));
for (int j=0; j < group_num; j++) {
__m256i xq8 = _mm256_set_epi32(
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 7]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 6]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 5]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 4]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 3]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 2]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 1]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 0]]
);
// __m128i hxq8 = _mm256_castsi256_si128(xq8);
// __m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
// __m128i hyq8 = _mm256_castsi256_si128(yq8);
// __m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
__m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i * group_num * 32 + j * 32));
// __m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
// __m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
// __m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
// __m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
__m128i hxq8 = _mm256_castsi256_si128(xq8);
__m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
__m128i hyq8 = _mm256_castsi256_si128(yq8);
__m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
// __m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
// __m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
__m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
__m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
__m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
__m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
// __m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(hzq16));
// __m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(hzq16, 1));
// __m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(lzq16));
// __m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(lzq16, 1));
__m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
__m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
// accu = _mm256_add_epi32(accu, hhzq32);
// accu = _mm256_add_epi32(accu, hlzq32);
// accu = _mm256_add_epi32(accu, llzq32);
// accu = _mm256_add_epi32(accu, lhzq32);
// }
haccu = _mm256_add_epi16(haccu, hzq16);
laccu = _mm256_add_epi16(laccu, lzq16);
}
// int sumi = hsum_i32_8(accu);
// *s = (float)sumi;
// #else
__m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(haccu));
__m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(haccu, 1));
__m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(laccu));
__m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(laccu, 1));
accu = _mm256_add_epi32(accu, hhzq32);
accu = _mm256_add_epi32(accu, hlzq32);
accu = _mm256_add_epi32(accu, llzq32);
accu = _mm256_add_epi32(accu, lhzq32);
}
int sumi = hsum_i32_8(accu);
*s = (float)sumi;
#else
int sumi = 0;
@ -3864,7 +3865,7 @@ void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * res
sumi += (int)y[i*4+3] * weight[3];
}
*s = (float)sumi;
// #endif
#endif
}
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {

View File

@ -11702,7 +11702,6 @@ struct llm_build_context {
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
cb(cur, "ffn_down", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
@ -11723,7 +11722,6 @@ struct llm_build_context {
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}