mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 22:38:58 +01:00
iq3_s_mult_shuffle: use lookup table on Metal
~4% faster token generation (TG) and ~2% faster prompt processing (PP) that way.
This commit is contained in:
parent
93034df760
commit
31cecc8734
130
ggml-metal.metal
130
ggml-metal.metal
@ -2546,8 +2546,10 @@ typedef struct {
|
|||||||
uint8_t signs[QK_K/8];
|
uint8_t signs[QK_K/8];
|
||||||
uint8_t scales[IQ3S_N_SCALE];
|
uint8_t scales[IQ3S_N_SCALE];
|
||||||
} block_iq3_s;
|
} block_iq3_s;
|
||||||
#define IQ3S_MULTIPLIER 518559
|
|
||||||
constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
|
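// When a shuffle is involved in the codebook, it is faster on Metal to read the grid from a lookup table (iq3s_grid) than to compute it, so the multiplier-based definitions below are commented out.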
|
||||||
|
//#define IQ3S_MULTIPLIER 518559
|
||||||
|
//constexpr constant static uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d;
|
half d;
|
||||||
@ -4085,6 +4087,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|||||||
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
constexpr constant static uint32_t iq3s_grid[512] = {
|
||||||
|
0x01010101, 0x0105070f, 0x010f030d, 0x0105090b, 0x010f0509, 0x01050109, 0x010f0707, 0x01050307,
|
||||||
|
0x010f0905, 0x01050505, 0x010f0105, 0x01050703, 0x010d0303, 0x01050b03, 0x010d0501, 0x01050101,
|
||||||
|
0x010d0701, 0x0105030f, 0x010d0b0d, 0x0105050b, 0x010d0109, 0x01050709, 0x010d0307, 0x01030b07,
|
||||||
|
0x010b0505, 0x01030105, 0x010b0705, 0x01030303, 0x010b0b03, 0x01030503, 0x010b0101, 0x01030701,
|
||||||
|
0x010b0301, 0x01030b0f, 0x010b050d, 0x0103010b, 0x01090709, 0x01030309, 0x01090b07, 0x01030507,
|
||||||
|
0x01090105, 0x01030705, 0x01090305, 0x01030b03, 0x01090503, 0x01030103, 0x01090701, 0x01030301,
|
||||||
|
0x01090b01, 0x0103050f, 0x0109010d, 0x0103070b, 0x01090309, 0x01030b09, 0x01090507, 0x01030107,
|
||||||
|
0x01090705, 0x01030305, 0x01070d05, 0x01010503, 0x01070103, 0x01010703, 0x01070301, 0x01010d01,
|
||||||
|
0x01070501, 0x0101010f, 0x0107070d, 0x0101030b, 0x01070d09, 0x01010509, 0x01070107, 0x01010907,
|
||||||
|
0x01070305, 0x01010d05, 0x01070505, 0x01010103, 0x01070903, 0x01010303, 0x01070d01, 0x01010501,
|
||||||
|
0x01070101, 0x0101090f, 0x0105030d, 0x01010d0b, 0x01050509, 0x01010109, 0x01050907, 0x01010307,
|
||||||
|
0x01050d05, 0x01010505, 0x01050105, 0x01010903, 0x01050303, 0x010f0d03, 0x01050501, 0x010f0101,
|
||||||
|
0x01050901, 0x010f030f, 0x03050d0d, 0x030f050b, 0x03050109, 0x030f0909, 0x03050307, 0x030d0d07,
|
||||||
|
0x03050505, 0x030d0105, 0x03050905, 0x030d0303, 0x03050f03, 0x030d0503, 0x03050101, 0x030d0901,
|
||||||
|
0x03050301, 0x030d0f0f, 0x0305050d, 0x030b010b, 0x03030909, 0x030b0309, 0x03030f07, 0x030b0507,
|
||||||
|
0x03030105, 0x030b0905, 0x03030305, 0x030b0f03, 0x03030703, 0x030b0103, 0x03030901, 0x03090301,
|
||||||
|
0x03030f01, 0x0309070f, 0x0303010d, 0x0309090b, 0x03030309, 0x03090f09, 0x03030707, 0x03090107,
|
||||||
|
0x03030905, 0x03090505, 0x03030f05, 0x03090703, 0x03030103, 0x03090903, 0x03030501, 0x03090f01,
|
||||||
|
0x03030701, 0x0309030f, 0x0303090d, 0x0309050b, 0x03030f09, 0x03070709, 0x03010307, 0x03070907,
|
||||||
|
0x03010505, 0x03070105, 0x03010705, 0x03070303, 0x03010903, 0x03070503, 0x03010101, 0x03070701,
|
||||||
|
0x03010301, 0x0307090f, 0x0301050d, 0x0307010b, 0x03010709, 0x03070309, 0x03010b07, 0x03070507,
|
||||||
|
0x03010105, 0x03070705, 0x03010305, 0x03070b03, 0x03010503, 0x03050103, 0x03010701, 0x03050301,
|
||||||
|
0x03010b01, 0x0305050f, 0x0301010d, 0x0305070b, 0x03010309, 0x03050b09, 0x03010507, 0x03050107,
|
||||||
|
0x030f0705, 0x03050305, 0x030f0b05, 0x03050503, 0x030f0103, 0x03050703, 0x030f0301, 0x03050b01,
|
||||||
|
0x030f0501, 0x0305010f, 0x030f070d, 0x0505030b, 0x050d0b09, 0x05050509, 0x050d0107, 0x05050707,
|
||||||
|
0x050d0305, 0x05050b05, 0x050d0505, 0x05050103, 0x050d0703, 0x05050303, 0x050b0b01, 0x05030501,
|
||||||
|
0x050b0101, 0x0503070f, 0x050b030d, 0x05030d0b, 0x050b0509, 0x05030109, 0x050b0707, 0x05030307,
|
||||||
|
0x050b0d05, 0x05030505, 0x05090105, 0x05030903, 0x05090303, 0x05030d03, 0x05090501, 0x05030101,
|
||||||
|
0x05090901, 0x0503030f, 0x05090d0d, 0x0503050b, 0x05090109, 0x05030909, 0x05090307, 0x05030d07,
|
||||||
|
0x05090505, 0x05030105, 0x05090905, 0x05030303, 0x05090d03, 0x05030503, 0x05090101, 0x05030901,
|
||||||
|
0x05090301, 0x05010d0f, 0x0507050d, 0x0501010b, 0x05070909, 0x05010309, 0x05070d07, 0x05010507,
|
||||||
|
0x05070105, 0x05010905, 0x05070305, 0x05010d03, 0x05070503, 0x05010103, 0x05070901, 0x05010301,
|
||||||
|
0x05070f01, 0x0501050f, 0x0507010d, 0x0501090b, 0x05070309, 0x05010f09, 0x05070507, 0x05010107,
|
||||||
|
0x05050905, 0x05010305, 0x05050f05, 0x05010503, 0x05050103, 0x05010903, 0x05050301, 0x05010f01,
|
||||||
|
0x05050501, 0x0501010f, 0x0505090d, 0x050f030b, 0x05050f09, 0x050f0709, 0x05050107, 0x050f0907,
|
||||||
|
0x05050305, 0x050f0f05, 0x05050705, 0x050f0103, 0x05050903, 0x050f0503, 0x05050f01, 0x050d0701,
|
||||||
|
0x05050101, 0x050d090f, 0x0505050d, 0x050d0f0b, 0x07050709, 0x070d0109, 0x07050907, 0x070d0507,
|
||||||
|
0x07050f05, 0x070d0705, 0x07030305, 0x070b0903, 0x07030503, 0x070b0f03, 0x07030701, 0x070b0301,
|
||||||
|
0x07030901, 0x070b050f, 0x0703010d, 0x070b070b, 0x07030309, 0x07090909, 0x07030507, 0x07090107,
|
||||||
|
0x07030705, 0x07090305, 0x07030b05, 0x07090503, 0x07030103, 0x07090703, 0x07030301, 0x07090b01,
|
||||||
|
0x07030501, 0x0709010f, 0x0703070d, 0x0709030b, 0x07030b09, 0x07090509, 0x07030107, 0x07090707,
|
||||||
|
0x07030305, 0x07090b05, 0x07030505, 0x07090103, 0x07010703, 0x07070303, 0x07010b01, 0x07070501,
|
||||||
|
0x07010101, 0x0707070f, 0x0701030d, 0x07070b0b, 0x07010509, 0x07070109, 0x07010707, 0x07070307,
|
||||||
|
0x07010b05, 0x07070505, 0x07010105, 0x07070703, 0x07010303, 0x07070b03, 0x07010501, 0x07070101,
|
||||||
|
0x07010701, 0x0707030f, 0x07010b0d, 0x0705050b, 0x09010109, 0x09050709, 0x09010307, 0x09050b07,
|
||||||
|
0x09010505, 0x09050105, 0x09010705, 0x09050303, 0x09010d03, 0x09050503, 0x09010101, 0x09050701,
|
||||||
|
0x090f0301, 0x09050d0f, 0x090f050d, 0x0905010b, 0x090f0909, 0x09050309, 0x090f0d07, 0x09050507,
|
||||||
|
0x090f0105, 0x09050905, 0x090d0305, 0x09050d03, 0x090d0503, 0x09050103, 0x090d0901, 0x09050301,
|
||||||
|
0x090d0d01, 0x0905050f, 0x090d010d, 0x0905090b, 0x090d0309, 0x09030d09, 0x090b0507, 0x09030107,
|
||||||
|
0x090b0905, 0x09030305, 0x090b0d05, 0x09030503, 0x090b0103, 0x09030903, 0x090b0301, 0x09030d01,
|
||||||
|
0x090b0501, 0x0903010f, 0x0909090d, 0x0903030b, 0x09090d09, 0x09030509, 0x09090107, 0x09030907,
|
||||||
|
0x09090305, 0x09030f05, 0x09090505, 0x09030103, 0x09090903, 0x09030303, 0x09090f01, 0x09030501,
|
||||||
|
0x09090101, 0x0903090f, 0x0909030d, 0x09030f0b, 0x09090509, 0x0b030109, 0x0b090907, 0x0b030307,
|
||||||
|
0x0b070f05, 0x0b010505, 0x0b070105, 0x0b010903, 0x0b070303, 0x0b010f03, 0x0b070701, 0x0b010101,
|
||||||
|
0x0b070901, 0x0b01030f, 0x0b070f0d, 0x0b01070b, 0x0b070109, 0x0b010909, 0x0b070507, 0x0b010f07,
|
||||||
|
0x0b070705, 0x0b010105, 0x0b070905, 0x0b010503, 0x0b070f03, 0x0b010703, 0x0b070301, 0x0b010901,
|
||||||
|
0x0b050501, 0x0b010f0f, 0x0b05070d, 0x0b01030b, 0x0b050909, 0x0d010509, 0x0d050f07, 0x0d010707,
|
||||||
|
0x0d050305, 0x0d010905, 0x0d050505, 0x0d0f0103, 0x0d050703, 0x0d0f0303, 0x0d050901, 0x0d0f0501,
|
||||||
|
0x0d050101, 0x0d0f070f, 0x0d05030d, 0x0d0f0b0b, 0x0d050509, 0x0d0f0109, 0x0d050707, 0x0d0d0307,
|
||||||
|
0x0d050b05, 0x0d0d0505, 0x0d050105, 0x0d0d0703, 0x0d050303, 0x0d0d0b03, 0x0d050501, 0x0d0d0101,
|
||||||
|
0x0d050701, 0x0d0b030f, 0x0d030b0d, 0x0d0b050b, 0x0d030109, 0x0d0b0709, 0x0f030307, 0x0f0b0b07,
|
||||||
|
0x0f030505, 0x0f0b0105, 0x0f030705, 0x0f0b0303, 0x0f030b03, 0x0f090503, 0x0f030101, 0x0f090701,
|
||||||
|
0x0f030301, 0x0f090b0f, 0x0f03050d, 0x0f09010b, 0x0f030709, 0x0f090309, 0x0f030b07, 0x0f090507,
|
||||||
|
0x0f030105, 0x0f090705, 0x0f030305, 0x0f090b03, 0x0f030503, 0x0f090103, 0x0f030701, 0x0f090301,
|
||||||
|
};
|
||||||
|
|
||||||
#define NGRID_IQ1S 512
|
#define NGRID_IQ1S 512
|
||||||
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
||||||
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
||||||
@ -4694,20 +4763,23 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||||||
{
|
{
|
||||||
int nval = 8;
|
int nval = 8;
|
||||||
int pos = (32*sgitg + tiisg)*nval;
|
int pos = (32*sgitg + tiisg)*nval;
|
||||||
uint32_t aux32;
|
|
||||||
thread int8_t * q = (thread int8_t *)&aux32;
|
|
||||||
for (int i = 0; i < nval; ++i) {
|
for (int i = 0; i < nval; ++i) {
|
||||||
aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
|
values[pos + i] = iq3s_grid[pos + i];
|
||||||
for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
|
|
||||||
values[pos + i] = aux32;
|
|
||||||
}
|
}
|
||||||
|
//uint32_t aux32;
|
||||||
|
//thread int8_t * q = (thread int8_t *)&aux32;
|
||||||
|
//for (int i = 0; i < nval; ++i) {
|
||||||
|
// aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
|
||||||
|
// for (int k = 0; k < 4; ++k) q[k] = iq3s_values[q[k]];
|
||||||
|
// values[pos + i] = aux32;
|
||||||
|
//}
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ix = tiisg;
|
const int ix = tiisg;
|
||||||
|
|
||||||
uint32_t aux32[2];
|
//uint32_t aux32[2];
|
||||||
thread const int8_t * grid = (thread const int8_t *)aux32;
|
//thread const int8_t * grid = (thread const int8_t *)aux32;
|
||||||
|
|
||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix;
|
||||||
|
|
||||||
@ -4735,11 +4807,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||||||
float2 sum = {0};
|
float2 sum = {0};
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
// This is slower than pre-computing the grid in shared memory and loading from there
|
// This is slower than pre-computing the grid in shared memory and loading from there
|
||||||
//aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
|
//aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
|
||||||
//aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
|
//aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
|
||||||
//for (int j = 0; j < 4; ++j) {
|
//for (int j = 0; j < 4; ++j) {
|
||||||
// sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
// sum[0] += yl[8*l + j + 0] * iq3s_values[grid[j+0]] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
||||||
// sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
// sum[1] += yl[8*l + j + 4] * iq3s_values[grid[j+4]] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
||||||
//}
|
//}
|
||||||
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
||||||
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
||||||
@ -5655,20 +5727,32 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
|
|||||||
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
||||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||||
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
||||||
uint32_t aux32[2];
|
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
||||||
thread const int8_t * grid = (thread const int8_t *)aux32;
|
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
||||||
aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
|
|
||||||
aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
||||||
reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
||||||
}
|
}
|
||||||
aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
|
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
||||||
aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
|
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
||||||
reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
||||||
}
|
}
|
||||||
|
//uint32_t aux32[2];
|
||||||
|
//thread const int8_t * grid = (thread const int8_t *)aux32;
|
||||||
|
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
|
||||||
|
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
|
||||||
|
//for (int i = 0; i < 4; ++i) {
|
||||||
|
// reg[0][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
||||||
|
// reg[1][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
||||||
|
//}
|
||||||
|
//aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f;
|
||||||
|
//aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f;
|
||||||
|
//for (int i = 0; i < 4; ++i) {
|
||||||
|
// reg[2][i] = dl * iq3s_values[grid[i+0]] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
||||||
|
// reg[3][i] = dl * iq3s_values[grid[i+4]] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
|
Loading…
Reference in New Issue
Block a user