mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-27 20:43:07 +01:00
QK_K = 64 tests pass on ARM_NEON and Metal
Sadly, that does not mean it actually works.
This commit is contained in:
parent
28e6146c11
commit
de64e061da
@ -2560,12 +2560,16 @@ typedef struct {
|
||||
uint8_t qs[QK4_NL/2];
|
||||
} block_iq4_nl;
|
||||
|
||||
#if QK_K == 64
|
||||
#define block_iq4_xs block_iq4_nl
|
||||
#else
|
||||
typedef struct {
|
||||
half d;
|
||||
uint16_t scales_h;
|
||||
uint8_t scales_l[QK_K/64];
|
||||
uint8_t qs[QK_K/2];
|
||||
} block_iq4_xs;
|
||||
#endif
|
||||
|
||||
//====================================== dot products =========================
|
||||
|
||||
@ -4346,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix;
|
||||
@ -4387,12 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
@ -4482,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix;
|
||||
@ -4533,12 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
@ -4628,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix;
|
||||
@ -4672,12 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
@ -5016,7 +4999,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
#if QK_K == 256
|
||||
const int ix = tiisg/2;
|
||||
const int il = tiisg%2;
|
||||
|
||||
@ -5055,12 +5037,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||
|
||||
y4 += 16 * 32;
|
||||
}
|
||||
#else
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
@ -5167,6 +5143,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
#if QK_K != 64
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@ -5260,6 +5237,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq1_s_f32(
|
||||
@ -5344,7 +5322,11 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
#if QK_K == 64
|
||||
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
#else
|
||||
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||
#endif
|
||||
}
|
||||
|
||||
//============================= templates and their specializations =============================
|
||||
@ -5770,6 +5752,9 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
||||
#if QK_K == 64
|
||||
dequantize_iq4_nl(xb, il, reg);
|
||||
#else
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
@ -5786,6 +5771,7 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
@ -6334,7 +6320,11 @@ template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_r
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||
#else
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
#endif
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
@ -6378,7 +6368,11 @@ template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_m
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
||||
#else
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
#endif
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
@ -6434,7 +6428,11 @@ template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||
#else
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
#endif
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
@ -7707,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
|
||||
#if QK_K == 64
|
||||
kernel_mul_mv_iq4_nl_f32_impl(
|
||||
#else
|
||||
kernel_mul_mv_iq4_xs_f32_impl(
|
||||
#endif
|
||||
src0[id],
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
|
@ -10262,7 +10262,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined __ARM_NEON
|
||||
#if defined __ARM_NEON && QK_K != 64
|
||||
|
||||
const uint8x16_t m8 = vdupq_n_u8(0x08);
|
||||
const uint8x16_t m7 = vdupq_n_u8(0x07);
|
||||
|
Loading…
Reference in New Issue
Block a user