mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 04:47:04 +01:00
iq1_s: slightly faster dot product
This commit is contained in:
parent
f604a17994
commit
5c977221d2
@ -4344,19 +4344,20 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||||||
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||||
|
|
||||||
float yl[32];
|
float yl[16];
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
|
|
||||||
const int nb32 = nb * (QK_K / 32);
|
const int nb32 = nb * (QK_K / 32);
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int ix = tiisg;
|
const int ix = tiisg/2;
|
||||||
|
const int il = tiisg%2;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix + 16 * il;
|
||||||
|
|
||||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
||||||
|
|
||||||
for (int i = 0; i < 32; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
yl[i] = y4[i];
|
yl[i] = y4[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4364,33 +4365,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||||||
const int ib = ib32 % (QK_K / 32);
|
const int ib = ib32 % (QK_K / 32);
|
||||||
|
|
||||||
device const block_iq1_s * xr = x + ibl;
|
device const block_iq1_s * xr = x + ibl;
|
||||||
device const uint8_t * qs = xr->qs + 4 * ib;
|
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
||||||
device const uint8_t * sc = xr->scales + 2 * ib;
|
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
||||||
device const half * dh = &xr->d;
|
device const half * dh = &xr->d;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
|
||||||
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
||||||
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
||||||
constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
|
|
||||||
constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
|
|
||||||
|
|
||||||
float4 sum = {0};
|
float2 sum = {0};
|
||||||
for (int j = 0; j < 8; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
sum[0] += yl[j+ 0] * grid1[j];
|
sum[0] += yl[j+ 0] * grid1[j];
|
||||||
sum[1] += yl[j+ 8] * grid2[j];
|
sum[1] += yl[j+ 8] * grid2[j];
|
||||||
sum[2] += yl[j+16] * grid3[j];
|
|
||||||
sum[3] += yl[j+24] * grid4[j];
|
|
||||||
}
|
}
|
||||||
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1) +
|
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
||||||
sum[2] * (2*(sc[1] & 7) + 1) + sum[3] * (2*((sc[1] >> 4) & 7) + 1));
|
|
||||||
|
|
||||||
dh += nb*sizeof(block_iq1_s)/2;
|
dh += nb*sizeof(block_iq1_s)/2;
|
||||||
qs += nb*sizeof(block_iq1_s);
|
qs += nb*sizeof(block_iq1_s);
|
||||||
sc += nb*sizeof(block_iq1_s);
|
sc += nb*sizeof(block_iq1_s);
|
||||||
}
|
}
|
||||||
|
|
||||||
y4 += 32 * 32;
|
y4 += 16 * 32;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
// TODO
|
// TODO
|
||||||
|
Loading…
Reference in New Issue
Block a user