iq3_s_mult: alternative multiplier / bit twidling

This commit is contained in:
Iwan Kawrakow 2024-03-03 08:51:28 +02:00
parent fe3c20b251
commit 726aed307a
2 changed files with 66 additions and 21 deletions

View File

@ -2370,9 +2370,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
}
//#define IQ3S_MULTIPLIER 746226
//#define IQ3S_MULTIPLIER 677595
#define IQ3S_MULTIPLIER 190842953LL
//#define IQ3S_MULTIPLIER 190842953LL
//#define IQ3S_MULTIPLIER 5718026
#define IQ3S_MULTIPLIER 898886
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@ -2390,11 +2391,17 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const int8_t * grid = (const int8_t *)aux32;
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
//aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
//aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
//aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101;
//aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101;
aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101;
aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101;
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
@ -5220,10 +5227,16 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const int8_t * q8 = bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) {
//aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
//aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
//aux32[0] = ((__vsub4(aux32[0], 0x01010101) >> 1) << 1) | 0x01010101;
//aux32[1] = ((__vsub4(aux32[1], 0x01010101) >> 1) << 1) | 0x01010101;
aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101;
aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101;
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);

View File

@ -4125,7 +4125,13 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
//#define IQ3S_MULTIPLIER 746226
//#define IQ3S_MULTIPLIER 717154
//#define IQ3S_MULTIPLIER 677595
#define IQ3S_MULTIPLIER 190842953
// Best PPL
//#define IQ3S_MULTIPLIER 190842953
//
//#define IQ3S_MULTIPLIER 5718026
#define IQ3S_MULTIPLIER 898886
#define IQ3S_BITS 3
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
@ -4146,20 +4152,34 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
for (int l = 0; l < 4; ++l) {
aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[0] = ((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[1] = ((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//for (int j = 0; j < 8; ++j) {
// y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
//}
aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101;
aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101;
for (int j = 0; j < 8; ++j) {
y[j] = db1 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = db1 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
qs += 8;
signs += 4;
for (int l = 0; l < 4; ++l) {
aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[0] = ((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//aux32[1] = ((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
//for (int j = 0; j < 8; ++j) {
// y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
//}
aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101;
aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101;
for (int j = 0; j < 8; ++j) {
y[j] = db2 * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = db2 * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
@ -10136,7 +10156,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const __m256i m1 = _mm256_set1_epi8(1);
const __m256i m7 = _mm256_set1_epi32(0x07070707);
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
const __m256i m0 = _mm256_setzero_si256();
//const __m256i m0 = _mm256_setzero_si256();
// aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
// aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
// aux32[0] = ((aux32[0] >> 1) << 1) | 0x01010101;
// aux32[1] = ((aux32[1] >> 1) << 1) | 0x01010101;
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
@ -10158,11 +10183,16 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
// v = MAX(((IQ3S_MULTIPLIER * idx) & 0x0f0f0f0f) - 1, 0)
const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
//const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
// v = (((v >> 1) & 0x07070707) << 1) | 0x01010101
//const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
//const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
//const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
@ -11392,9 +11422,11 @@ static void iq3xs_init_grid512(void) {
const uint8_t * q4 = (const uint8_t *)&aux32;
for (int k = 0; k < grid_size; ++k) {
int8_t * pos = (int8_t *)(the_grid + k);
aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f;
//aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f;
aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101;
for (int i = 0; i < 4; ++i) {
pos[i] = 2*((q4[i]-1)/2) + 1;
//pos[i] = 2*((q4[i]-1)/2) + 1;
pos[i] = 2*(q4[i]/2) + 1;
}
}