iq3_s(multiplier): use SIMD also in dequantize

This commit is contained in:
Iwan Kawrakow 2024-03-01 12:02:39 +02:00
parent 9c752ff0d3
commit 1cc7cb2b46

View File

@ -2389,24 +2389,28 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
int32_t aux32[2];
const uint8_t * grid = (const uint8_t *)aux32;
const int8_t * grid = (const int8_t *)aux32;
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
aux32[0] = __vadd4(((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101);
aux32[1] = __vadd4(((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1, 0x01010101);
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
aux32[1] = __vsub4(aux32[1] ^ signs1, signs1);
for (int j = 0; j < 8; ++j) {
y[j] = d * grid[j];
}
#else
for (int j = 0; j < 8; ++j) {
//y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
//y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
// const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
// const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
// const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
// const uint8_t signs = x[i].signs[4*ib + il];
// for (int j = 0; j < 4; ++j) {
// y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
// y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
// }
#endif
#else
assert(false);
#endif