From 970b5ab7ca1a335b178f6831534b066d529c90c5 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 27 Dec 2024 20:21:28 -0500 Subject: [PATCH] ggml-cuda : add TQ2_0 support --- ggml/src/ggml-common.h | 3 + ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 30 +++++ ggml/src/ggml-cuda/ggml-cuda.cu | 1 + ggml/src/ggml-cuda/mmq.cu | 4 + ggml/src/ggml-cuda/mmq.cuh | 112 ++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cu | 12 ++ .../template-instances/generate_cu_files.py | 1 + .../template-instances/mmq-instance-tq2_0.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 61 ++++++++++ tests/test-backend-ops.cpp | 7 +- 11 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f13fd4dea..b4c6d2766 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -126,6 +126,9 @@ typedef sycl::half2 ggml_half2; #define QI6_K (QK_K / (4*QR6_K)) #define QR6_K 2 +#define QI2_0 (QK_K / (4*QR2_0)) +#define QR2_0 4 + #define QI2_XXS (QK_K / (4*QR2_XXS)) #define QR2_XXS 4 diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2c0a56226..8ff930dad 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -440,6 +440,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI6_K; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_0; + static constexpr int qi = QI2_0; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 3896f956d..c2f723ca5 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -277,6 +277,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); } +template +static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_tq2_0 * x = (const block_tq2_0 *) vx; + + const int64_t tid = threadIdx.x; // 0..64 + const int64_t n = tid/32; // 0 or 1 + const int64_t l = tid - 32*n; // 0..32 + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float d = __half2float(x[i].d); + y[l+ 0] = d * ((q >> 0) & 3) - d; + y[l+32] = d * ((q >> 2) & 3) - d; + y[l+64] = d * ((q >> 4) & 3) - d; + y[l+96] = d * ((q >> 6) & 3) - d; +} + template static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -515,6 +535,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k dequantize_block_q6_K<<>>(vx, y); } +template +static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_tq2_0<<>>(vx, y); +} + template static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; @@ -613,6 +639,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_q5_K_cuda; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; + case GGML_TYPE_TQ2_0: + return dequantize_row_tq2_0_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_IQ2_XS: @@ -660,6 +688,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_q5_K_cuda; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; + case GGML_TYPE_TQ2_0: + return dequantize_row_tq2_0_cuda; case GGML_TYPE_IQ2_XXS: return dequantize_row_iq2_xxs_cuda; case GGML_TYPE_IQ2_XS: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c180adc84..07765de7f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2860,6 +2860,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 270251df4..c5f3b21b3 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -61,6 +61,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q6_K: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_TQ2_0: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ2_XXS: mul_mat_q_case(ctx, args, stream); break; @@ -113,6 +116,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cd508a1d..0cd671bc9 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -63,6 +63,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_Q5_K: return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -139,6 +140,9 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } +// tile_x_sizes{qs, dm, sc} + +// TODO: TQ2_0 to minimize shared mem #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} @@ -161,6 +165,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + type == GGML_TYPE_TQ2_0 ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : @@ -195,6 +200,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + type == GGML_TYPE_TQ2_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : @@ -1808,6 +1814,103 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #endif // INT8_MMA_AVAILABLE } +// This is the first "simple" type with a block size of 256 +template static __device__ __forceinline__ void load_tiles_tq2_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % QI2_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0; + + if (need_check) { + i = min(i, i_max); + } + + const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride; + const int qs0 = get_int_b2(bxi->qs, kqsx); + +#ifdef INT8_MMA_AVAILABLE + +#pragma unroll + for (int l = 0; l < QR2_0; ++l) { + // 0..7, 32..39 + // 8..15, 40..47 + // 16..23, 48..55 + // 24..31, 56..63 + // FIXME: this might assume WARP_SIZE is >= 32 + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101); + } +#else + x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0; +#endif // INT8_MMA_AVAILABLE + } + + // TODO: does this work with WARP_SIZE != 32? +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride; + + const int k = threadIdx.x % (QI2_0/2); + +#ifdef INT8_MMA_AVAILABLE + + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d; +#else + x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d; +#endif // INT8_MMA_AVAILABLE + } +} + +template +static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl( + &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2427,6 +2530,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_TQ2_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_tq2_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; @@ -2916,6 +3027,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e3b912d87..d8c8893b6 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -14,6 +14,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : + type == GGML_TYPE_TQ2_0 ? vec_dot_tq2_0_q8_1 : type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : @@ -37,6 +38,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : + type == GGML_TYPE_TQ2_0 ? VDR_TQ2_0_Q8_1_MMVQ : type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : @@ -271,6 +273,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda( mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } +static void mul_mat_vec_tq2_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +} + static void mul_mat_vec_iq2_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { @@ -385,6 +394,9 @@ void ggml_cuda_op_mul_mat_vec_q( case GGML_TYPE_Q6_K: mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_TQ2_0: + mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d7874e6ea..00dc20884 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -23,6 +23,7 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block} TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", + "GGML_TYPE_TQ2_0", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" ] diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu new file mode 100644 index 000000000..2780a4d47 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-tq2_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_TQ2_0); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40091a0ef..80019ab24 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -524,6 +524,36 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_TQ2_0_Q8_1_MMVQ 2 +#define VDR_TQ2_0_Q8_1_MMQ 8 + +// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block +template static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl( + const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < QR2_0; ++i0) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi = (v[i] >> (2*i0)) & 0x03030303; + + sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product + } + + // TODO: batch subtract by using d8 sum + sumf += d8[i0] * sumi; + } + + return d2 * sumf; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { @@ -786,6 +816,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); } +static __device__ __forceinline__ float vec_dot_tq2_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx; + + // iqs 0..7 all need bq8_offset 0, 1, 2, 3 + // iqs 8..15 all need bq8_offset 4, 5, 6, 7 + const int bq8_offset = QR2_0 * (iqs / 8); + + int v[VDR_TQ2_0_Q8_1_MMVQ]; + int u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ]; + float d8[QR2_0]; + +#pragma unroll + for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_b2(btq2_0->qs, iqs + i); + } + +#pragma unroll + for (int i0 = 0; i0 < QR2_0; ++i0) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0; + + for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) { + u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i); + } + d8[i0] = __low2float(bq8i->ds); + } + + return vec_dot_tq2_0_q8_1_impl(v, u, btq2_0->d, d8); +} + #define VDR_IQ2_XXS_Q8_1_MMVQ 2 #define VDR_IQ2_XXS_Q8_1_MMQ 2 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ccdd3fb57..1f080ecf3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3375,7 +3375,8 @@ static const ggml_type all_types[] = { GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + // GGML_TYPE_TQ1_0, + GGML_TYPE_TQ2_0, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, @@ -3387,6 +3388,7 @@ static const ggml_type base_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests GGML_TYPE_Q4_K, + GGML_TYPE_TQ2_0, GGML_TYPE_IQ2_XXS }; @@ -3397,7 +3399,8 @@ static const ggml_type other_types[] = { GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, - // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends + // GGML_TYPE_TQ1_0, + GGML_TYPE_TQ2_0, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,