Attempt 2

This commit is contained in:
Iwan Kawrakow 2024-03-12 18:40:13 +02:00
parent 9188523f70
commit 9f805264dc

View File

@ -4713,7 +4713,7 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; grid32[0] = iq1s_grid[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f; grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) { for (int j = 0; j < 8; ++j) {
@ -7616,10 +7616,12 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs, const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint32_t *iq1s_grid) { const uint32_t *iq1s_grid) {
#if QK_K == 256 #if QK_K == 256
const int ib32 = iqs;
const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int * q8 = (const int *)bq8_1[ib32].qs; const int * q8 = (const int *)bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
int grid0 = grid[0] & 0x0f0f0f0f; int grid0 = grid[0] & 0x0f0f0f0f;
int grid1 = (grid[0] >> 4) & 0x0f0f0f0f; int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
sumi = dpct::dp4a(q8[2*l+1], grid1, dpct::dp4a(q8[2*l+0], grid0, sumi)); sumi = dpct::dp4a(q8[2*l+1], grid1, dpct::dp4a(q8[2*l+0], grid0, sumi));