mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
SOTA 2-bit quants (#4773)
* iq2_xxs: basics
* iq2_xxs: scalar and AVX2 dot products. Needed to change Q8_K to have quants in the -127...127 range, else the IQ2_XXS AVX implementation becomes very awkward. The alternative would have been to use Q8_0 instead. Perhaps I'll change later, for now this is what we have.
* iq2_xxs: ARM_NEON dot product. Somehow strangely slow (112 ms/token).
* iq2_xxs: WIP Metal. Dequantize works, something is still wrong with the dot product.
* iq2_xxs: Metal dot product now works. We have PP-512 = 475 t/s, TG-128 = 47.3 t/s. Not the greatest performance, but not complete garbage either.
* iq2_xxs: slightly faster dot product. TG-128 is now 48.4 t/s.
* iq2_xxs: slightly faster dot product. TG-128 is now 50.9 t/s.
* iq2_xxs: even faster Metal dot product. TG-128 is now 54.1 t/s. Strangely enough, putting the signs lookup table into shared memory has a bigger impact than the grid values being in shared memory.
* iq2_xxs: dequantize CUDA kernel - fix conflict with master
* iq2_xxs: quantized CUDA dot product (MMVQ). We get TG-128 = 153.1 t/s.
* iq2_xxs: slightly faster CUDA dot product. TG-128 is now at 155.1 t/s.
* iq2_xxs: add to llama ftype enum
* iq2_xxs: fix MoE on Metal
* Fix missing MMQ ops when on hipBLAS. I had put the ggml_supports_mmq call at the wrong place.
* Fix bug in quantize_row_iq2_xxs. The 0.25f factor was missing. Great detective work by @ggerganov!
* Fixing tests
* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
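For orientation before reading the diff: the kernels added below all reconstruct the effective scale of a 32-weight group from the fp16 super-block scale and a 4-bit nibble stored in the top bits of the group's sign word. The snippet is an editorial sketch distilled from those kernels, not code from this patch; the helper name and the standalone form are made up for illustration.

#include <stdint.h>

// Sketch: effective scale applied to one group of 32 weights, as computed in the
// CUDA and Metal kernels in this diff. 'd' is the fp16 super-block scale
// (block_iq2_xxs.d), 'aux32' is the 32-bit word whose top nibble holds the 4-bit
// group scale. The trailing 0.25f is the factor that the "Fix bug in
// quantize_row_iq2_xxs" bullet above refers to.
static inline float iq2_xxs_group_scale(float d, uint32_t aux32) {
    return d * (0.5f + (float)(aux32 >> 28)) * 0.25f;
}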
This commit is contained in:
parent 668b31fc7d
commit dd5ae06405

205 ggml-cuda.cu
@@ -477,6 +477,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
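The size arithmetic behind the "2-bit" claim follows directly from the struct just added; a sketch, assuming QK_K == 256 as in the kernels below:

// 2 bytes (half d) + 32 * 2 bytes (uint16_t qs[QK_K/8]) = 66 bytes per 256 weights,
// i.e. 66 * 8 / 256 = 2.0625 bits per weight, matching the comment this commit
// adds to ggml-metal.metal.
_Static_assert(2 + (256/8)*2 == 66, "iq2_xxs block is 66 bytes at QK_K == 256");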
@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #endif
 }
 
+static const __device__ uint64_t kgrid_iq2xxs[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
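To make the kernel above easier to audit, here is a plain-C sketch of how one 32-weight group of an IQ2_XXS block is unpacked. It assumes the layout the kernel uses: each group occupies four uint16_t of qs, where q2[0..1] hold eight byte-sized indices into kgrid_iq2xxs (four used here, one per 8-weight row) and q2[2..3] hold four 7-bit sign patterns plus a 4-bit scale in the top nibble. Function and parameter names are illustrative only; the byte reinterpretation of the grid entries assumes little-endian storage, exactly as the kernel does.

#include <stdint.h>

// Illustrative host-side decode of one 32-weight group (mirrors
// dequantize_block_iq2_xxs above; not part of the ggml API).
static void iq2_xxs_decode_group(float d, const uint16_t q2[4],
                                 const uint64_t * grid256,   // kgrid_iq2xxs
                                 const uint8_t  * signs128,  // ksigns_iq2xs
                                 float out[32]) {
    const uint8_t * aux8  = (const uint8_t *)q2;              // q2[0..1]: 4 grid indices
    const uint32_t  aux32 = q2[2] | ((uint32_t)q2[3] << 16);  // 4*7 sign bits + 4-bit scale
    const float     dl    = d * (0.5f + (float)(aux32 >> 28)) * 0.25f;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid  = (const uint8_t *)(grid256 + aux8[l]);  // 8 magnitudes
        const uint8_t   signs = signs128[(aux32 >> 7*l) & 127];        // 7 bits -> 8 signs
        for (int j = 0; j < 8; ++j) {
            out[8*l + j] = dl * (float)grid[j] * ((signs & (1u << j)) ? -1.0f : 1.0f);
        }
    }
}

Only q2[0] and q2[1] contribute grid indices; q2[2] and q2[3] are consumed entirely as the sign-and-scale word. The sign table turns each 7-bit pattern into 8 sign bits (the added eighth bit keeps the popcount even, as can be checked against ksigns_iq2xs above).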
@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
+        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
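A quick sanity check on why the QR2_XXS == 8 branch can simply use `const int ib32 = iqs` (a sketch of the arithmetic, under the QK_K == 256 assumption):

// QK_K = 256 weights per block, QR2_XXS = 8,
// QI2_XXS = QK_K / (4*QR2_XXS) = 256 / 32 = 8,
// so the MMVQ loop hands vec_dot_iq2_xxs_q8_1 eight values of iqs per block,
// one per 32-weight group, each paired with one block_q8_1 of 32 activations.
_Static_assert(256 / (4*8) == 8, "QI2_XXS is 8 for QK_K == 256");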
@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ2_XXS:
            return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 
+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
40 ggml-metal.m
@@ -88,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(get_rows_i32);
+    GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
@@ -106,6 +107,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
     //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@@ -121,6 +123,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -133,6 +136,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@@ -145,6 +149,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
@@ -379,6 +384,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_K);
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(get_rows_i32);
+        GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(group_norm);
         GGML_METAL_ADD_KERNEL(norm);
@@ -397,6 +403,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
         //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@@ -412,6 +419,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
         if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -425,6 +433,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@@ -437,6 +446,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
         }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
@@ -502,6 +512,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(get_rows_i32);
+    GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
@@ -520,6 +531,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
     //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@@ -535,6 +547,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -548,6 +561,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@@ -560,6 +574,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
     }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
@@ -1541,6 +1556,7 @@ bool ggml_metal_graph_compute(
                         case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                    }
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1653,6 +1669,12 @@ bool ggml_metal_graph_compute(
                            nth1 = 32;
                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                        } break;
+                    case GGML_TYPE_IQ2_XXS:
+                        {
+                            nth0 = 4;
+                            nth1 = 16;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+                        } break;
                    default:
                        {
                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -1686,9 +1708,14 @@ bool ggml_metal_graph_compute(
 
                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                    src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+                    //src0t == GGML_TYPE_IQ2_XXS ||
                    src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
+                else if (src0t == GGML_TYPE_IQ2_XXS) {
+                    [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                }
                else if (src0t == GGML_TYPE_Q4_K) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
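The 256*8+128 threadgroup allocation in the hunk above is sized for the two lookup tables that the IQ2_XXS Metal kernel copies into shared memory (see kernel_mul_mv_iq2_xxs_f32_impl in the ggml-metal.metal diff further down): 256 uint64_t grid entries plus 128 uint8_t sign patterns. A quick check of the arithmetic, as a sketch:

// Threadgroup memory budget for the IQ2_XXS mat-vec kernels:
//   256 entries of kgrid_iq2xxs * sizeof(uint64_t) = 2048 bytes
//   128 entries of ksigns_iq2xs * sizeof(uint8_t)  =  128 bytes
//   total                                          = 2176 bytes == 256*8 + 128
_Static_assert(256*8 + 128 == 2176, "shared grid + sign tables take 2176 bytes");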
@@ -1778,6 +1805,7 @@ bool ggml_metal_graph_compute(
                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+                        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
                        default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                    }
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1893,6 +1921,12 @@ bool ggml_metal_graph_compute(
                            nth1 = 32;
                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
                        } break;
+                    case GGML_TYPE_IQ2_XXS:
+                        {
+                            nth0 = 4;
+                            nth1 = 16;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+                        } break;
                    default:
                        {
                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -1942,9 +1976,14 @@ bool ggml_metal_graph_compute(
 
                if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
                    src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                    //src2t == GGML_TYPE_IQ2_XXS ||
                    src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
+                else if (src2t == GGML_TYPE_IQ2_XXS) {
+                    [encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                }
                else if (src2t == GGML_TYPE_Q4_K) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
@@ -1982,6 +2021,7 @@ bool ggml_metal_graph_compute(
                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
                        case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
+                        case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
                        default: GGML_ASSERT(false && "not implemented");
                    }
 
314 ggml-metal.metal
@@ -2446,6 +2446,12 @@ typedef struct {
 } block_q6_K;
 // 210 bytes / block
 
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+// 66 bytes / block for QK_K = 256, so 2.0625 bpw
+
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -3468,6 +3474,221 @@ kernel void kernel_mul_mv_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
 }
 
+// ======================= "True" 2-bit
+
+constexpr constant static uint64_t kgrid_iq2xxs[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
+
+constexpr constant static uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
+void kernel_mul_mv_iq2_xxs_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
+    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+#if QK_K == 256
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xxs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            device const uint8_t * aux8 = (device const uint8_t *)q2;
+            const uint32_t aux32 = q2[2] | (q2[3] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float sum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * sum;
+
+            dh += nb*sizeof(block_iq2_xxs)/2;
+            q2 += nb*sizeof(block_iq2_xxs)/2;
+        }
+
+        y4 += 32 * 32;
+    }
+#else
+    // TODO
+#endif
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
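The block at the top of kernel_mul_mv_iq2_xxs_f32_impl that fills values and shared_signs is a cooperative copy of the two lookup tables into threadgroup memory; the commit message notes that caching the sign table this way mattered more for speed than caching the grid. Assuming the nth0 = 4, nth1 = 16 dispatch set up in ggml-metal.m (64 threads, i.e. two 32-lane simdgroups), the indexing covers every table entry exactly once. A sketch of that arithmetic:

// Cooperative table-copy coverage (sketch; assumes 2 simdgroups * 32 lanes = 64 threads,
// which is what the 4 x 16 threadgroup size in ggml-metal.m gives):
//   grid:  each thread copies nval = 4 entries, 64 * 4 = 256 = all of kgrid_iq2xxs
//   signs: each thread copies nval = 2 entries, 64 * 2 = 128 = all of ksigns_iq2xs
_Static_assert(64*4 == 256 && 64*2 == 128, "64 threads cover both shared tables");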
@@ -3739,6 +3960,31 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const void * src0,
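A short check of the tiling that dequantize_iq2_xxs relies on, since the get_rows and mul_mm templates instantiated below call it per 4x4 tile (a sketch, QK_K == 256 assumed):

// il ranges over 0..15; each call fills one thread-local 4x4 tile (16 floats),
// so 16 calls cover the 256 weights of one block. ib32 = il/2 picks the
// 32-weight group, il%2 picks its first or second half of 16 quants.
_Static_assert(16 * 16 == 256, "16 tiles of 16 values cover one QK_K == 256 block");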
@ -4278,6 +4524,7 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|||||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-matrix multiplication
|
// matrix-matrix multiplication
|
||||||
@ -4314,6 +4561,7 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|||||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// indirect matrix-matrix multiplication
|
// indirect matrix-matrix multiplication
|
||||||
@ -4362,6 +4610,7 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|||||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-vector multiplication
|
// matrix-vector multiplication
|
||||||
@ -5134,3 +5383,68 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||||||
tiisg,
|
tiisg,
|
||||||
sgitg);
|
sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
||||||
|
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
||||||
|
device const char * ids,
|
||||||
|
device const char * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant uint64_t & nbi1,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & ne13,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint & r2,
|
||||||
|
constant uint & r3,
|
||||||
|
constant int & idx,
|
||||||
|
device const char * src00,
|
||||||
|
device const char * src01,
|
||||||
|
device const char * src02,
|
||||||
|
device const char * src03,
|
||||||
|
device const char * src04,
|
||||||
|
device const char * src05,
|
||||||
|
device const char * src06,
|
||||||
|
device const char * src07,
|
||||||
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||||
|
|
||||||
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||||
|
|
||||||
|
tgpig.z = tgpig.z%(ne12*ne13);
|
||||||
|
|
||||||
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||||
|
|
||||||
|
kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
src0[id],
|
||||||
|
(device const float *) (src1 + bid*nb11),
|
||||||
|
dst + bid*ne0,
|
||||||
|
ne00,
|
||||||
|
ne01,
|
||||||
|
ne02,
|
||||||
|
ne10,
|
||||||
|
ne12,
|
||||||
|
ne0,
|
||||||
|
ne1,
|
||||||
|
r2,
|
||||||
|
r3,
|
||||||
|
shared_values,
|
||||||
|
tgpig,
|
||||||
|
tiisg,
|
||||||
|
sgitg);
|
||||||
|
}
|
||||||
|
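On the Metal side the new type needs only plumbing: the usual get_rows / mul_mm / mul_mm_id template instantiations above, plus this indirect matrix-vector wrapper, which looks up the expert id from the ids tensor and forwards to kernel_mul_mv_iq2_xxs_f32_impl exactly like the existing *_id kernels do.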
294 ggml-quants.c
@ -2340,6 +2340,138 @@ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t *
    return (n/QK_K*sizeof(block_q6_K));
}

// ====================== "True" 2-bit (de)-quantization

void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
    (void)x;
    (void)y;
    (void)k;
    assert(k % QK_K == 0);
    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}

static const uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};

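Every byte in the grid above is one of 0x08, 0x19 or 0x2b, i.e. the magnitudes 8, 25 and 43, so each uint64_t entry packs eight unsigned weight magnitudes. A minimal sketch of how one 8-bit index expands into those magnitudes (the helper is illustrative, not part of the commit; it mirrors the pointer cast the dequantizer uses below):

#include <stdint.h>
#include <string.h>

// Expand one grid index into its 8 byte magnitudes, in the same memory order
// the dequantizer sees when it casts (iq2xxs_grid + idx) to const uint8_t *.
static void iq2xxs_expand_grid(const uint64_t grid[256], uint8_t idx, uint8_t out[8]) {
    const uint64_t packed = grid[idx];
    memcpy(out, &packed, 8);
}

Index 0, for example, expands to eight 8s; the per-weight signs and the sub-block scale are applied on top of these magnitudes.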
static const uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};

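The sign table obeys a simple rule: ksigns_iq2xs[i] == i | (parity(i) << 7), where parity(i) is 1 when i has an odd number of set bits. The low 7 bits are the stored sign index and bit 7 is chosen so that every 8-bit mask has an even number of negative signs, which is how 8 signs are stored in only 7 bits: the eighth sign is implied by the first seven.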
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
    assert(k % QK_K == 0);
    const int nb = k / QK_K;

    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    for (int i = 0; i < nb; i++) {

        const float d = GGML_FP16_TO_FP32(x[i].d);

        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
                y += 8;
            }
        }
    }
}
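Reading the loop together with the block layout: each group of 32 weights consumes the two uint32 words produced by the memcpy. The four bytes of the first word index iq2xxs_grid (4 x 8 magnitudes), while the second word packs four 7-bit sign indices in its low 28 bits and a 4-bit sub-block scale s in its top 4 bits, giving the effective scale db = d * (0.5 + s) * 0.25.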

void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
    assert(k % QK_K == 0);
    block_iq2_xxs * restrict y = vy;
    quantize_row_iq2_xxs_reference(x, y, k);
}

size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
    assert(k % QK_K == 0);
    (void)hist; // TODO: collect histograms

    for (int j = 0; j < n; j += k) {
        block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
        quantize_row_iq2_xxs_reference(src + j, y, k);
    }
    return (n/QK_K*sizeof(block_iq2_xxs));
}

//===================================== Q8_K ==============================================

void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@ -2362,7 +2494,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict
            x += QK_K;
            continue;
        }
        //const float iscale = -128.f/max;
        // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
        const float iscale = -127.f/max;
        for (int j = 0; j < QK_K; ++j) {
            int v = nearest_int(iscale*x[j]);
            y[i].qs[j] = MIN(127, v);
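With iscale = -127.f/max the Q8_K quants stay within [-127, 127]; for max = 2.0, for instance, iscale is -63.5, so -2.0 maps to 127 and +2.0 to -127. A plausible reason this matters for IQ2_XXS (the comment only says the AVX implementation gets awkward otherwise): the AVX2 kernel below flips the sign of the q8 values with _mm256_sign_epi8, and -128 has no int8 negation, so keeping -128 out of the range keeps that trick exact.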
@ -7065,3 +7199,161 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
}

#endif

static const int8_t keven_signs_q2xs[1024] = {
     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
};

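keven_signs_q2xs appears to be ksigns_iq2xs expanded to bytes: row k holds eight int8 values that are -1 where the corresponding sign bit is set and +1 elsewhere, so one 64-bit load (signs64 below) yields a ready-made +-1 mask for 8 weights. A generation sketch under that reading (not part of the commit):

#include <stdio.h>

// Rebuild the 128 rows of 8 signs: the low 7 bits of k are the stored index,
// bit 7 is the even-parity bit, and a set bit at position j means -1.
int main(void) {
    for (int k = 0; k < 128; ++k) {
        int parity = 0;
        for (int b = 0; b < 7; ++b) parity ^= (k >> b) & 1;
        const int mask = k | (parity << 7);
        for (int j = 0; j < 8; ++j) {
            printf("%2d,%s", (mask >> j) & 1 ? -1 : 1, j == 7 ? "\n" : " ");
        }
    }
    return 0;
}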
void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);

    const block_iq2_xxs * restrict x = vx;
    const block_q8_K    * restrict y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    uint32_t aux32[4];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    int8x16x4_t q2u;
    int8x16x4_t q2s;
    int8x16x4_t q8b;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * restrict q2 = x[i].qs;
        const int8_t   * restrict q8 = y[i].qs;
        float sumf1 = 0, sumf2 = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            q8b = vld1q_s8_x4(q8); q8 += 64;
            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
        }
        sumf += d*(sumf1 + sumf2);
    }
    *s = 0.25f * sumf;

#elif defined(__AVX2__)

    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

    uint32_t aux32[4];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    __m256 accumf = _mm256_setzero_ps();
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * restrict q2 = x[i].qs;
        const int8_t   * restrict q8 = y[i].qs;
        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
                                                   signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
                                                   signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
            const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
            const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
            const uint16_t ls1 = aux32[1] >> 28;
            const uint16_t ls2 = aux32[3] >> 28;
            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
            sumi1 = _mm256_add_epi32(sumi1, p1);
            sumi2 = _mm256_add_epi32(sumi2, p2);
        }

        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);

    }

    *s = 0.125f * hsum_float_8(accumf);

#else

    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    float sumf = 0.f;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint16_t * restrict q2 = x[i].qs;
        const int8_t   * restrict q8 = y[i].qs;
        int32_t bsum = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            memcpy(aux32, q2, 2*sizeof(uint32_t));
            q2 += 4;
            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
            int32_t sumi = 0;
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
                const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
                }
                q8 += 8;
            }
            bsum += sumi * ls;
        }
        sumf += d * bsum;
    }
    *s = 0.125f * sumf;
#endif
}
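Note how the three paths agree on scaling: the scalar and AVX2 branches weight each 32-wide group by (2*s + 1) and multiply by 0.125 at the end, while the NEON branch uses (0.5 + s) with a final 0.25; both equal the dequantization factor, since 0.125*(2s + 1) = 0.25*(s + 0.5). The AVX2 branch applies the signs to the q8 side rather than to the grid values because _mm256_maddubs_epi16 treats its first operand as unsigned bytes, which is exactly what the grid magnitudes are.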
ggml-quants.h
@ -165,6 +165,14 @@ typedef struct {
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");

// (Almost) "true" 2-bit quantization.
// Due to the need to use blocks as per ggml design, it ends up using
// 2.0625 bpw because of the 16-bit scale for each block of 256.
typedef struct {
    ggml_fp16_t d;
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

// Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
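The 2.0625 bpw figure in the comment follows directly from the struct: a block of QK_K = 256 weights carries 16 bits for the fp16 scale d plus 32 uint16 values (512 bits) of packed data, 528 bits in total, and 528 / 256 = 2.0625 bits per weight.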
@ -180,6 +188,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);

void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@ -194,6 +203,7 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);

// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@ -209,6 +219,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);

// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@ -222,3 +233,4 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
26 ggml.c
@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
        .vec_dot = ggml_vec_dot_q6_K_q8_K,
        .vec_dot_type = GGML_TYPE_Q8_K,
    },
    [GGML_TYPE_IQ2_XXS] = {
        .type_name = "iq2_xxs",
        .blck_size = QK_K,
        .type_size = sizeof(block_iq2_xxs),
        .is_quantized = true,
        .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
        .from_float = quantize_row_iq2_xxs,
        .from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
        .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
        .vec_dot_type = GGML_TYPE_Q8_K,
    },
    [GGML_TYPE_Q8_K] = {
        .type_name = "q8_K",
        .blck_size = QK_K,
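With the type registered in type_traits, generic code can reach the new callbacks the same way the quantization test further below does. A minimal round-trip sketch, assuming ggml.h is on the include path and that the ggml_internal_get_type_traits accessor the tests use is available (everything else comes from the hunks above; from_float is still a stub in this commit):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    float src[256], dst[256];          // one IQ2_XXS block is QK_K = 256 weights
    unsigned char buf[1024];           // plenty of room for one 66-byte block_iq2_xxs

    for (int i = 0; i < 256; ++i) src[i] = 0.1f*(i - 128);

    ggml_type_traits_t t = ggml_internal_get_type_traits(GGML_TYPE_IQ2_XXS);
    if (t.from_float && t.to_float) {
        t.from_float(src, buf, 256);   // quantize one block (stubbed for now)
        t.to_float(buf, dst, 256);     // dequantize it back
        printf("%s: round trip done\n", t.type_name);
    }
    return 0;
}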
@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
        case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
        case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
        case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
        case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
        case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
    }
@ -7436,6 +7448,7 @@ static void ggml_compute_forward_add(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
            } break;
@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
            } break;
@ -7814,6 +7828,7 @@ static void ggml_compute_forward_acc(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        default:
            {
                GGML_ASSERT(false);
@ -10455,6 +10470,7 @@ static void ggml_compute_forward_out_prod(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
            } break;
@ -10629,6 +10645,7 @@ static void ggml_compute_forward_set(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        default:
            {
                GGML_ASSERT(false);
@ -10823,6 +10840,7 @@ static void ggml_compute_forward_get_rows(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
            {
                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
            } break;
@ -11459,6 +11477,7 @@ static void ggml_compute_forward_alibi(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_Q8_K:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
@ -11533,6 +11552,7 @@ static void ggml_compute_forward_clamp(
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_Q8_K:
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
@ -18648,6 +18668,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                block_q6_K * block = (block_q6_K*)dst + start / QK_K;
                result = ggml_quantize_q6_K(src + start, block, n, n, hist);
            } break;
        case GGML_TYPE_IQ2_XXS:
            {
                GGML_ASSERT(start % QK_K == 0);
                block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
                result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
            } break;
        case GGML_TYPE_F16:
            {
                int elemsize = sizeof(ggml_fp16_t);
3 ggml.h
@ -339,6 +339,7 @@ extern "C" {
        GGML_TYPE_Q5_K = 13,
        GGML_TYPE_Q6_K = 14,
        GGML_TYPE_Q8_K = 15,
        GGML_TYPE_IQ2_XXS = 16,
        GGML_TYPE_I8,
        GGML_TYPE_I16,
        GGML_TYPE_I32,
@ -373,6 +374,7 @@ extern "C" {
        GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
        GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
    };

    // available tensor operations:
@ -2067,6 +2069,7 @@ extern "C" {
    GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
    GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
    GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
    GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);

    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

llama.cpp
@ -2222,6 +2222,7 @@ struct llama_model_loader {
            case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
            case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
            case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
            case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
            default:
                {
                    LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -2593,6 +2594,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
        case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
        case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
        case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw";

        default: return "unknown, may not work";
    }
@ -9038,6 +9040,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
        case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
        case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:quantized_type = GGML_TYPE_IQ2_XXS; break;

        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
    }
1 llama.h
@ -103,6 +103,7 @@ extern "C" {
    LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
    LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
    LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
    LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors

    LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
tests/test-quantize-fns.cpp
@ -134,6 +134,11 @@ int main(int argc, char * argv[]) {
            continue;
        }

        if ((ggml_type)i == GGML_TYPE_IQ2_XXS) {
            printf("Skip %s due to missing quantization functionality\n", ggml_type_name((ggml_type) i));
            continue;
        }

        printf("Testing %s\n", ggml_type_name((ggml_type) i));

        if (qfns.from_float && qfns.to_float) {
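The skip is consistent with quantize_row_iq2_xxs_reference being a stub in this commit: from_float would not yet produce meaningful data, so the round-trip error checks in this test would fail for IQ2_XXS.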