iq3s_mult: ARM and Metal

Commit 5b9c8785fa (parent b6402fa757), from a mirror of
https://github.com/ggerganov/llama.cpp.git (synced 2024-12-29 07:34:18 +01:00).
@@ -2546,7 +2546,12 @@ typedef struct {
     uint8_t signs[QK_K/8];
     uint8_t scales[IQ3S_N_SCALE];
 } block_iq3_s;
+#ifdef IQ3S_SLOW_MULT
 #define IQ3S_MULTIPLIER 190842953
+#else
+//#define IQ3S_MULTIPLIER 898886
+#define IQ3S_MULTIPLIER 842866
+#endif
 
 typedef struct {
     half d;
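This first hunk, on the Metal side, splits the grid multiplier into two builds: the IQ3S_SLOW_MULT build keeps 190842953 and snaps each generated byte to an odd value with 2*((q-1)/2)+1, while the default build uses 842866 (898886 was another candidate, left commented out) and simply ORs in 0x01010101 so every byte is already odd. The sketch below is a plain-C illustration of what the two paths compute, not code from the commit; the helper names are made up, and the two multipliers are not interchangeable since each is tuned for its own post-processing.

/* Plain-C sketch (illustrative only) of the two IQ3S grid-generation paths. */
#include <stdint.h>
#include <stdio.h>

#define IQ3S_MULT_SLOW 190842953u   /* IQ3S_SLOW_MULT build */
#define IQ3S_MULT_FAST 842866u      /* default build        */

static uint32_t iq3s_grid_slow(uint32_t idx) {
    uint32_t aux = (IQ3S_MULT_SLOW * idx) & 0x0f0f0f0f;               /* one nibble per byte       */
    int8_t * q = (int8_t *)&aux;
    for (int k = 0; k < 4; ++k) q[k] = (int8_t)(2*((q[k]-1)/2) + 1);  /* snap down to odd (0 -> 1) */
    return aux;
}

static uint32_t iq3s_grid_fast(uint32_t idx) {
    return ((IQ3S_MULT_FAST * idx) & 0x0f0f0f0f) | 0x01010101;        /* set bit 0 of every byte   */
}

int main(void) {
    for (uint32_t idx = 0; idx < 8; ++idx)
        printf("idx %u: slow %08x  fast %08x\n", (unsigned)idx,
               (unsigned)iq3s_grid_slow(idx), (unsigned)iq3s_grid_fast(idx));
    return 0;
}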
@@ -4691,15 +4696,21 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
     {
-        uint32_t aux32;
-        thread int8_t * q = (thread int8_t *)&aux32;
         int nval = 8;
         int pos = (32*sgitg + tiisg)*nval;
+#ifdef IQ3S_SLOW_MULT
+        uint32_t aux32;
+        thread int8_t * q = (thread int8_t *)&aux32;
         for (int i = 0; i < nval; ++i) {
             aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
             for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1;
             values[pos + i] = aux32;
         }
+#else
+        for (int i = 0; i < nval; ++i) {
+            values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101;
+        }
+#endif
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
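In this kernel hunk the default build fills the threadgroup table without the per-byte fixup loop: every lane writes nval = 8 consecutive 32-bit entries starting at pos = (32*sgitg + tiisg)*nval, so one simdgroup of 32 lanes covers 256 entries and two simdgroups cover the full 512-entry (9-bit) index space before the barrier. Below is a hypothetical host-side C simulation of that fill, only to make the index arithmetic concrete; it is not the Metal code.

/* Host-side simulation (illustrative) of the shared-table fill done by the kernel. */
#include <stdint.h>
#include <assert.h>

#define IQ3S_MULTIPLIER 842866u

static void fill_values(uint32_t * values, int n_simdgroup) {
    const int nval = 8;
    for (int sgitg = 0; sgitg < n_simdgroup; ++sgitg) {
        for (int tiisg = 0; tiisg < 32; ++tiisg) {
            int pos = (32*sgitg + tiisg)*nval;                 /* each "lane" owns 8 slots */
            for (int i = 0; i < nval; ++i) {
                values[pos + i] = ((IQ3S_MULTIPLIER * (uint32_t)(pos + i)) & 0x0f0f0f0f) | 0x01010101;
            }
        }
    }
}

int main(void) {
    uint32_t values[512];
    fill_values(values, 2);              /* 2 simdgroups x 32 lanes x 8 words = 512 entries */
    assert(values[0] == 0x01010101);     /* index 0 always maps to all-ones                 */
    return 0;
}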
@@ -4733,17 +4744,16 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
             float2 sum = {0};
             for (int l = 0; l < 4; ++l) {
-                //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
-                //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
-                //threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
-                //threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
-                threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + qs[2*l+0] +
-                                                    select(0, 256, qh[0] & kmask_iq2xs[2*l+0]));
-                threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + qs[2*l+1] +
-                                                    select(0, 256, qh[0] & kmask_iq2xs[2*l+1]));
+                // This is slower than pre-computing the grid in shared memory and loading from there
+                //aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
+                //aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
+                //for (int j = 0; j < 4; ++j) {
+                //    sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+                //    sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+                //}
+                threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+                threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
                 for (int j = 0; j < 4; ++j) {
-                    //sum[0] += yl[8*l + j + 0] * (2*((grid[j+0] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
-                    //sum[1] += yl[8*l + j + 4] * (2*((grid[j+4] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 }
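The dot-product loop now reads the grid from the precomputed values table; the table index for each group of four weights is the 8-bit qs byte with one qh bit placed at bit 8 (the on-the-fly multiply is kept as a comment because, per the comment, it is slower than the shared-memory lookup). A small stand-alone C illustration of the 9-bit index construction, with made-up sample data:

/* Illustrative C model of the 9-bit grid-index construction used above. */
#include <stdint.h>
#include <stdio.h>

static inline uint16_t iq3s_index(const uint8_t * qs, uint8_t qh, int l, int which) {
    /* which = 0 mirrors (qh << (8-2*l)) & 256, which = 1 mirrors (qh << (7-2*l)) & 256 */
    return (uint16_t)(qs[2*l + which] | ((qh << (8 - 2*l - which)) & 256));
}

int main(void) {
    uint8_t qs[8] = { 12, 200, 7, 33, 90, 1, 255, 64 };   /* made-up sample data */
    uint8_t qh    = 0xb1;                                 /* one high bit per packed value */
    for (int l = 0; l < 4; ++l) {
        printf("l=%d: idx1=%d idx2=%d\n", l,
               (int)iq3s_index(qs, qh, l, 0), (int)iq3s_index(qs, qh, l, 1));
    }
    return 0;
}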
@@ -5657,6 +5667,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
     const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
     uint32_t aux32[2];
     thread const int8_t * grid = (thread const int8_t *)aux32;
+#ifdef IQ3S_SLOW_MULT
     aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
     aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
     for (int i = 0; i < 4; ++i) {
@@ -5669,6 +5680,20 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
         reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
         reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
     }
+#else
+    aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+#endif
 }
 
 template <typename type4x4>
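dequantize_iq3_s gains the same two builds: the slow path keeps the odd-snapping expression 2*((grid[i]-1)/2)+1, while the default path bakes the | 0x01010101 into the grid word so the bytes can be used directly, scaled by dl = d * (1 + 2*scale_nibble) and sign-flipped per element (a set bit in the sign byte makes the weight negative, matching select(1, -1, ...)). The following plain-C sketch, with invented inputs, dequantizes one group of 8 weights the fast way; it mirrors the logic above but is not the Metal code.

/* Plain-C sketch (illustrative) of fast-path dequantization for 8 weights. */
#include <stdint.h>
#include <stdio.h>

#define IQ3S_MULTIPLIER 842866u

static void dequant8(float d, uint8_t scale, const uint8_t * qs, uint8_t qh,
                     uint8_t sign_byte, float * out) {
    const float dl = d * (1 + 2*scale);                               /* scale is the 4-bit sub-block scale */
    uint32_t aux[2];
    aux[0] = ((IQ3S_MULTIPLIER * (uint32_t)(qs[0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101;
    aux[1] = ((IQS3_DUMMY_ZERO + IQ3S_MULTIPLIER * (uint32_t)(qs[1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101;
    const int8_t * grid = (const int8_t *)aux;                        /* 8 odd grid values, one per byte    */
    for (int i = 0; i < 8; ++i) {
        out[i] = dl * grid[i] * ((sign_byte & (1u << i)) ? -1.0f : 1.0f);
    }
}

#define IQS3_DUMMY_ZERO 0u   /* keeps the two aux lines visually parallel; contributes nothing */

int main(void) {
    uint8_t qs[2] = { 17, 99 }, qh = 3, signs = 0xa5;                 /* invented inputs */
    float w[8];
    dequant8(0.05f, 7, qs, qh, signs, w);
    for (int i = 0; i < 8; ++i) printf("%g ", w[i]);
    printf("\n");
    return 0;
}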
@@ -10070,6 +10070,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
                                                vmovl_u8(vget_low_u8(idx_l)));
             const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
                                                vmovl_u8(vget_high_u8(idx_l)));
+#ifdef IQ3S_SLOW_MULT
             q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2));
             q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2));
             q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2));
@@ -10078,6 +10079,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1);
             q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
             q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
+#else
+            q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1);
+            q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1);
+            q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1);
+            q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1);
+#endif
 
             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
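The remaining hunks are the ARM NEON path of ggml_vec_dot_iq3_s_q8_K (the per-file headers were lost in this view; the function presumably lives in ggml-quants.c). Under IQ3S_SLOW_MULT each vector of masked nibbles goes through subtract-1, clamp, unsigned shift right, shift left, OR 1 to land on an odd value; the default build instead ORs m1 straight into the multiply-and-mask result. Assuming m1 and m0 are byte vectors of 1s and 0s (which the OR-to-make-odd usage suggests), a scalar model of the two per-byte fixups looks like this; the helper names are illustrative:

/* Scalar model (illustrative) of the two NEON fixup paths applied to each q3s byte. */
#include <stdint.h>
#include <assert.h>

static int8_t slow_fixup(int8_t q) {               /* IQ3S_SLOW_MULT path               */
    int8_t t = (int8_t)(q - 1);                    /* vsubq_s8(q, m1)                   */
    if (t < 0) t = 0;                              /* vmaxq_s8(..., m0)                 */
    return (int8_t)((((uint8_t)t >> 1) << 1) | 1); /* vshrq_n_u8, vshlq_n_s8, vorrq m1  */
}

static int8_t fast_fixup(int8_t q) {               /* default path: just set bit 0      */
    return (int8_t)(q | 1);
}

int main(void) {
    /* Both fixups turn the masked nibbles 0..15 into odd values in 1..15, but they
     * round differently, which is why each build pairs with its own multiplier. */
    for (int q = 0; q < 16; ++q) {
        assert(slow_fixup((int8_t)q) % 2 == 1);
        assert(fast_fixup((int8_t)q) % 2 == 1);
    }
    return 0;
}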
@@ -10094,8 +10101,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
             vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
 
-            signs += 4;
-
             q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
             q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
 
@@ -10103,6 +10108,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
             sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
             sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
+
+            signs += 4;
         }
         sumf += d*(sumi1 + sumi2);
     }
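The last two hunks only move the signs += 4 pointer advance to the end of the ib32 iteration, after both dot products and the scale accumulation; the work done per step is unchanged. As a reference for what that step accumulates, here is an illustrative scalar equivalent (not the NEON code): q3 holds the 64 reconstructed weights of two 32-wide sub-blocks and q8 the matching q8_K quants, with one packed scale byte shared between them.

/* Illustrative scalar reference for one ib32 step of the loop above. */
#include <stdint.h>

static void iq3s_accumulate_step(const int8_t q3[64], const int8_t q8[64], uint8_t packed_scale,
                                 int32_t * sumi1, int32_t * sumi2) {
    int32_t p1 = 0, p2 = 0;
    for (int j = 0; j < 32; ++j) p1 += q3[j]      * q8[j];        /* p1: q3s.val[0..1] . q8b.val[0..1] */
    for (int j = 0; j < 32; ++j) p2 += q3[32 + j] * q8[32 + j];   /* p2: q3s.val[2..3] . q8b.val[2..3] */
    *sumi1 += p1 * (1 + 2*(packed_scale & 0xf));                  /* low-nibble scale  */
    *sumi2 += p2 * (1 + 2*(packed_scale >> 4));                   /* high-nibble scale */
}

int main(void) {
    int8_t q3[64] = {0}, q8[64] = {0};                            /* dummy data, just to exercise it */
    int32_t sumi1 = 0, sumi2 = 0;
    iq3s_accumulate_step(q3, q8, 0x53, &sumi1, &sumi2);
    return (int)(sumi1 + sumi2);
}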