mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
Add q3_s and q1_s
This commit is contained in:
parent
21b0867433
commit
ad251954eb
588
ggml-sycl.cpp
588
ggml-sycl.cpp
@ -3492,6 +3492,31 @@ typedef struct dpct_type_block_iq3_xxs {
|
||||
} block_iq3_xxs;
|
||||
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
|
||||
|
||||
#define QR3_XS 8
|
||||
#define QI3_XS (QK_K / (4*QR3_XS))
|
||||
#if QK_K == 64
|
||||
#define IQ3S_N_SCALE 2
|
||||
#else
|
||||
#define IQ3S_N_SCALE QK_K/64
|
||||
#endif
|
||||
typedef struct {
|
||||
sycl::half d;
|
||||
uint8_t qs[QK_K/4];
|
||||
uint8_t qh[QK_K/32];
|
||||
uint8_t signs[QK_K/8];
|
||||
uint8_t scales[IQ3S_N_SCALE];
|
||||
} block_iq3_s;
|
||||
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
|
||||
|
||||
#define QR1_S 8
|
||||
#define QI1_S (QK_K / (4*QR1_S))
|
||||
typedef struct {
|
||||
sycl::half d;
|
||||
uint8_t qs[QK_K/8];
|
||||
uint8_t scales[QK_K/16];
|
||||
} block_iq1_s;
|
||||
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
@ -5025,7 +5050,207 @@ static dpct::global_memory<const uint32_t, 1> iq3xxs_grid(
|
||||
0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c,
|
||||
0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
||||
});
|
||||
static dpct::global_memory<const uint32_t, 1> iq3s_grid(
|
||||
sycl::range<1>(256),
|
||||
{
|
||||
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
||||
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
||||
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
||||
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
||||
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
||||
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
||||
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
||||
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
||||
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
||||
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
||||
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
||||
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
||||
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
||||
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
||||
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
||||
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
||||
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
||||
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
||||
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
||||
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
||||
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
||||
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
||||
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
||||
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
||||
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
||||
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
||||
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
||||
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
||||
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
||||
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
||||
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
||||
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
||||
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
||||
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
||||
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
||||
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
||||
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
||||
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
||||
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
||||
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
||||
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
||||
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
||||
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
||||
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
||||
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
||||
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
||||
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
||||
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
||||
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
||||
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
||||
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
||||
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
||||
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
||||
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
||||
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
||||
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
||||
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
||||
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
||||
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
||||
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
||||
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
||||
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
||||
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
||||
};
|
||||
|
||||
static dpct::global_memory<const uint64_t, 1> iq1s_grid(
|
||||
sycl::range<1>(256),
|
||||
{
|
||||
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
||||
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
||||
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
||||
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
||||
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
||||
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
||||
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
||||
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
||||
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
||||
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
||||
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
||||
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
||||
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
||||
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
||||
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
||||
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
||||
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
||||
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
||||
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
||||
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
||||
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
||||
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
||||
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
||||
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
||||
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
||||
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
||||
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
||||
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
||||
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
||||
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
||||
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
||||
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
||||
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
||||
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
||||
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
||||
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
||||
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
||||
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
||||
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
||||
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
||||
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
||||
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
||||
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
||||
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
||||
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
||||
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
||||
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
||||
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
||||
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
||||
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
||||
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
||||
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
||||
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
||||
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
||||
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
||||
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
||||
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
||||
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
||||
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
||||
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
||||
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
||||
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
||||
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
||||
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
||||
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
||||
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
||||
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
||||
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
||||
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
||||
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
||||
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
||||
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
||||
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
||||
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
||||
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
||||
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
||||
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
||||
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
||||
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
||||
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
||||
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
||||
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
||||
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
||||
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
||||
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
||||
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
||||
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
||||
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
||||
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
||||
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
||||
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
||||
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
||||
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
||||
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
||||
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
||||
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
||||
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
||||
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
||||
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
||||
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
||||
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
||||
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
||||
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
||||
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
||||
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
||||
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
||||
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
||||
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
||||
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
||||
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
||||
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
||||
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
||||
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
||||
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
||||
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
||||
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
||||
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
||||
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
||||
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
||||
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
||||
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
||||
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
||||
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
||||
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
||||
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
||||
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
||||
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
||||
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
||||
};
|
||||
static dpct::global_memory<const uint8_t, 1> ksigns_iq2xs(
|
||||
sycl::range<1>(128),
|
||||
{
|
||||
@ -5179,6 +5404,62 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
|
||||
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint32_t *iq3s_grid,
|
||||
const uint8_t *ksigns_iq2xs,
|
||||
const uint8_t *kmask_iq2xs) {
|
||||
|
||||
const int i = item_ct1.get_group(2);
|
||||
const block_iq3_s * x = (const block_iq3_s *) vx;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
#if QK_K == 256
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const uint8_t * qs = x[i].qs + 8*ib;
|
||||
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + q3[2*il+0]);
|
||||
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + q3[2*il+1]);
|
||||
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
|
||||
const uint8_t signs = x[i].signs[4*ib + il];
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
||||
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
||||
}
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint32_t *iq3s_grid,
|
||||
const uint8_t *ksigns_iq2xs,
|
||||
const uint8_t *kmask_iq2xs) {
|
||||
|
||||
const int i = item_ct1.get_group(2);
|
||||
const block_iq1_s * x = (const block_iq1_s *) vx;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
#if QK_K == 256
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const int i8 = 4*ib+il;
|
||||
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
|
||||
const float d = (float)x[i].d * (2*(h & 7) + 1);
|
||||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
DPCT1110:4: The total declared local variable size in device function
|
||||
dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
|
||||
@ -8025,6 +8306,91 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
|
||||
#endif
|
||||
}
|
||||
|
||||
static __dpct_inline__ float
|
||||
vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
|
||||
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
|
||||
const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
|
||||
#if DPCT_COMPATIBILITY_TEMP >= \
|
||||
MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
#if QK_K == 256
|
||||
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
|
||||
|
||||
const int ib32 = iqs;
|
||||
const uint8_t * qs = bq2->qs + 8*ib32;
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
int sumi = 0;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
|
||||
const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
|
||||
uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
|
||||
((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
|
||||
uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
|
||||
((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
|
||||
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
|
||||
grid1[0] ^ signs0, signs0, std::minus<>());
|
||||
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
||||
grid2[0] ^ signs1, signs1, std::minus<>());
|
||||
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
||||
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
||||
q8 += 8;
|
||||
aux32 >>= 7;
|
||||
}
|
||||
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * (0.5f + aux32) * bq8_1[ib32].ds[0];
|
||||
return d * sumi;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
}
|
||||
|
||||
static __dpct_inline__ float
|
||||
vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
|
||||
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
|
||||
const uint64_t *iq1s_grid, const uint64_t *ksigns64) {
|
||||
#if QK_K == 256
|
||||
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
||||
|
||||
const int ib32 = iqs;
|
||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
|
||||
const uint8_t h1 = bq1->scales[2*ib32+0];
|
||||
const uint8_t h2 = bq1->scales[2*ib32+1];
|
||||
#if DPCT_COMPATIBILITY_TEMP >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
const int * q8 = (const int *)bq8_1[ib32].qs;
|
||||
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
||||
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
||||
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
||||
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1);
|
||||
sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2);
|
||||
sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3);
|
||||
sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4);
|
||||
}
|
||||
#else
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
||||
const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
||||
const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
||||
const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
sumi1 += q8[j+ 0] * grid1[j];
|
||||
sumi2 += q8[j+ 8] * grid2[j];
|
||||
sumi3 += q8[j+16] * grid3[j];
|
||||
sumi4 += q8[j+24] * grid4[j];
|
||||
}
|
||||
#endif
|
||||
const float d = (float)bq1->d * bq8_1[ib32].ds[0];
|
||||
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
|
||||
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
|
||||
int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
|
||||
@ -8790,6 +9156,98 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr>
|
||||
static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
|
||||
i += blocks_per_warp) {
|
||||
const int ibx = row*blocks_per_row + i; // x block index
|
||||
|
||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||
|
||||
const int iqs =
|
||||
vdr *
|
||||
(item_ct1.get_local_id(2) %
|
||||
(qi / vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr>
|
||||
static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / qk;
|
||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp = 0.0f;
|
||||
|
||||
const block_q_t * x = (const block_q_t *) vx;
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
|
||||
i += blocks_per_warp) {
|
||||
const int ibx = row*blocks_per_row + i; // x block index
|
||||
|
||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||
|
||||
const int iqs =
|
||||
vdr *
|
||||
(item_ct1.get_local_id(2) %
|
||||
(qi / vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp +=
|
||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||
}
|
||||
|
||||
if (item_ct1.get_local_id(2) == 0) {
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
@ -10475,6 +10933,64 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
||||
dpct::queue_ptr stream) {
|
||||
const int nb = k / QK_K;
|
||||
{
|
||||
iq3s_grid.init(*stream);
|
||||
ksigns_iq2xs.init(*stream);
|
||||
kmask_iq2xs.init(*stream);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
|
||||
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
|
||||
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq3_s(
|
||||
vx, y, item_ct1, iq3s_grid_ptr_ct1,
|
||||
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
||||
dpct::queue_ptr stream) {
|
||||
const int nb = k / QK_K;
|
||||
{
|
||||
iq1s_grid.init(*stream);
|
||||
ksigns_iq2xs.init(*stream);
|
||||
kmask_iq2xs.init(*stream);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
|
||||
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
|
||||
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq1_s(
|
||||
vx, y, item_ct1, iq1s_grid_ptr_ct1,
|
||||
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_sycl(const void *__restrict__ vx,
|
||||
dst_t *__restrict__ y, const int k,
|
||||
@ -10525,6 +11041,10 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
|
||||
return dequantize_row_iq2_xs_sycl;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
return dequantize_row_iq3_xxs_sycl;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_sycl;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_sycl<float>;
|
||||
default:
|
||||
@ -10565,6 +11085,10 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
|
||||
return dequantize_row_iq2_xs_sycl;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
return dequantize_row_iq3_xxs_sycl;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return dequantize_row_iq3_s_sycl;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_unary_sycl<sycl::half>;
|
||||
default:
|
||||
@ -11154,6 +11678,61 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
iq3s_grid.init(*stream);
|
||||
ksigns64.init(*stream);
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
|
||||
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(32)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S, block_iq3_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1,
|
||||
iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
iq1s_grid.init(*stream);
|
||||
ksigns64.init(*stream);
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
|
||||
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(32)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1,
|
||||
iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols_x,
|
||||
@ -13902,8 +14481,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
return max_compute_capability >= VER_GEN9 ? 128 : 64;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
return max_compute_capability >= VER_GEN9 ? 128 : 64;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 64;
|
||||
default:
|
||||
@ -13964,6 +14546,12 @@ inline void ggml_sycl_op_mul_mat_vec_q(
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user