ggml-cuda : add TQ2_0 support
commit 970b5ab7ca (parent 5cd85b5e00)
@@ -126,6 +126,9 @@ typedef sycl::half2 ggml_half2;
 #define QI6_K (QK_K / (4*QR6_K))
 #define QR6_K 2

+#define QI2_0 (QK_K / (4*QR2_0))
+#define QR2_0 4
+
 #define QI2_XXS (QK_K / (4*QR2_XXS))
 #define QR2_XXS 4

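For orientation, these constants describe the TQ2_0 block layout. A sketch of the struct they refer to (it lives in ggml-common.h; reproduced from memory for reference, not part of this diff):

// 256 ternary weights packed four per byte, plus one fp16 scale: ~2.06 bits/weight.
typedef struct {
    uint8_t   qs[QK_K/4]; // 64 bytes, 2 bits per weight, each field in {0, 1, 2}
    ggml_half d;          // per-block scale; a field q dequantizes to d * (q - 1)
} block_tq2_0;
// QR2_0 = 4: four 2-bit quants share each byte.
// QI2_0 = QK_K/(4*QR2_0) = 16: 32-bit ints of quant data per block.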
@@ -440,6 +440,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
     static constexpr int qi = QI6_K;
 };

+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_TQ2_0> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_0;
+    static constexpr int qi = QI2_0;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
     static constexpr int qk = QK_K;

@@ -277,6 +277,26 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 }

+template<typename dst_t>
+static __global__ void dequantize_block_tq2_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_tq2_0 * x = (const block_tq2_0 *) vx;
+
+    const int64_t tid = threadIdx.x; // 0..63
+    const int64_t n   = tid/32;      // 0 or 1
+    const int64_t l   = tid - 32*n;  // 0..31
+
+    const uint8_t q = x[i].qs[32*n + l];
+    dst_t * y = yy + i*QK_K + 128*n;
+
+    float d = __half2float(x[i].d);
+    y[l+ 0] = d * ((q >> 0) & 3) - d;
+    y[l+32] = d * ((q >> 2) & 3) - d;
+    y[l+64] = d * ((q >> 4) & 3) - d;
+    y[l+96] = d * ((q >> 6) & 3) - d;
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

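The kernel maps each 2-bit field q in {0, 1, 2} to d*q - d, i.e. {-d, 0, +d}. A minimal host-side reference of the same indexing, useful for checking the kernel's output (the function name and the GGML_FP16_TO_FP32 helper are assumptions, not part of this diff):

static void dequantize_block_tq2_0_ref(const block_tq2_0 * x, float * y) {
    const float d = GGML_FP16_TO_FP32(x->d);
    for (int n = 0; n < 2; ++n) {         // two halves of 128 output values
        for (int l = 0; l < 32; ++l) {    // one source byte per l, as in the kernel
            const uint8_t q = x->qs[32*n + l];
            for (int s = 0; s < 4; ++s) { // the four 2-bit fields of that byte
                y[128*n + 32*s + l] = d * ((q >> (2*s)) & 3) - d;
            }
        }
    }
}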
@@ -515,6 +535,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
 }

+template<typename dst_t>
+static void dequantize_row_tq2_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_tq2_0<<<nb, 64, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;

@@ -613,6 +639,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_TQ2_0:
+            return dequantize_row_tq2_0_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_IQ2_XS:

@@ -660,6 +688,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_TQ2_0:
+            return dequantize_row_tq2_0_cuda;
         case GGML_TYPE_IQ2_XXS:
             return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_IQ2_XS:

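With these two cases in place, a TQ2_0 tensor flows through the standard conversion lookup. Roughly how it gets driven (a sketch with assumed pointer names; the to_fp32_cuda_t signature is recalled from convert.cuh, not shown in this diff):

// nelements must be a multiple of QK_K; src_dev/dst_dev are hypothetical device pointers.
const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_TQ2_0);
to_fp32(src_dev, dst_dev, nelements, stream);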
@@ -2860,6 +2860,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_Q5_K:
                 case GGML_TYPE_Q6_K:
                 case GGML_TYPE_Q8_K:
+                case GGML_TYPE_TQ2_0:
                 case GGML_TYPE_IQ1_M:
                 case GGML_TYPE_IQ1_S:
                 case GGML_TYPE_IQ2_S:

@@ -61,6 +61,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
            break;
+        case GGML_TYPE_TQ2_0:
+            mul_mat_q_case<GGML_TYPE_TQ2_0>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
             break;

@@ -113,6 +116,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ2_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ2_S:

@@ -63,6 +63,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ2_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ2_S:

@@ -139,6 +140,9 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }

+// tile_x_sizes{qs, dm, sc}
+
+// TODO: TQ2_0 to minimize shared mem
 #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
 #define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}

@@ -161,6 +165,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
         type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
         type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_TQ2_0   ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
         type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
         type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :

@@ -195,6 +200,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
         type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_TQ2_0   ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
         type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
         type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :

@@ -1808,6 +1814,103 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #endif // INT8_MMA_AVAILABLE
 }

+// This is the first "simple" type with a block size of 256
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_tq2_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI2_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_0) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_0) + threadIdx.x/QI2_0;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
+        const int qs0 = get_int_b2(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+        for (int l = 0; l < QR2_0; ++l) {
+            // 0..7, 32..39
+            // 8..15, 40..47
+            // 16..23, 48..55
+            // 24..31, 56..63
+            // FIXME: this might assume WARP_SIZE is >= 32
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = __vsub4((qs0 >> (2*l)) & 0x03030303, 0x01010101);
+        }
+#else
+        x_qs[i*(2*WARP_SIZE + 1) + kqsx] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    // TODO: does this work with WARP_SIZE != 32?
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_0/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_0) + threadIdx.x/(QI2_0/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_tq2_0 * bxi = (const block_tq2_0 *) x + kbx0 + i*stride;
+
+        const int k = threadIdx.x % (QI2_0/2);
+
+#ifdef INT8_MMA_AVAILABLE
+
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4 + k] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_tq2_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_TQ2_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+#pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_0*VDR_TQ2_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMQ>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0/QR2_0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2)], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                    // x_df[i*(WARP_SIZE/QI2_0) + i/QI2_0], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

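On the MMA path, load_tiles_tq2_0 widens the 2-bit quants to signed bytes at load time so the existing q8_0 MMA vec dot can be reused unchanged. What the __vsub4 line computes, isolated as an illustrative sketch (helper name is an assumption, not part of this diff):

// qs0 packs 16 two-bit quants; bit-plane l keeps one field per byte.
static __device__ int unpack_tq2_0_plane(int qs0, int l) {
    const int masked = (qs0 >> (2*l)) & 0x03030303; // each byte now in {0, 1, 2}
    return __vsub4(masked, 0x01010101);             // per-byte subtract: {-1, 0, +1} as int8x4
}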
@@ -2427,6 +2530,14 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };

+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_TQ2_0> {
+    static constexpr int              vdr          = VDR_TQ2_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_tq2_0<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_tq2_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
     static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;

@@ -2916,6 +3027,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_TQ2_0);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);

@@ -14,6 +14,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         type == GGML_TYPE_Q4_K    ? vec_dot_q4_K_q8_1 :
         type == GGML_TYPE_Q5_K    ? vec_dot_q5_K_q8_1 :
         type == GGML_TYPE_Q6_K    ? vec_dot_q6_K_q8_1 :
+        type == GGML_TYPE_TQ2_0   ? vec_dot_tq2_0_q8_1 :
         type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
         type == GGML_TYPE_IQ2_XS  ? vec_dot_iq2_xs_q8_1 :
         type == GGML_TYPE_IQ2_S   ? vec_dot_iq2_s_q8_1 :

@@ -37,6 +38,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
         type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
         type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_TQ2_0   ? VDR_TQ2_0_Q8_1_MMVQ :
         type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
         type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
         type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :

@@ -271,6 +273,13 @@ static void mul_mat_vec_q6_K_q8_1_cuda(
     mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
 }

+static void mul_mat_vec_tq2_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_TQ2_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
 static void mul_mat_vec_iq2_xxs_q8_1_cuda(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

@@ -385,6 +394,9 @@ void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;
+        case GGML_TYPE_TQ2_0:
+            mul_mat_vec_tq2_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
             break;

@@ -23,6 +23,7 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+    "GGML_TYPE_TQ2_0",
     "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
     "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_TQ2_0);

@@ -524,6 +524,36 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 }

+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_TQ2_0_Q8_1_MMVQ 2
+#define VDR_TQ2_0_Q8_1_MMQ  8
+
+// Can use the same for both mmvq and mmq, because there are no sub-scales in a TQ2_0 block
+template <int vdr> static __device__ __forceinline__ float vec_dot_tq2_0_q8_1_impl(
+    const int * __restrict__ v, const int * __restrict__ u, const float & d2, const float * __restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR2_0; ++i0) {
+        int sumi = 0;
+
+#pragma unroll
+        for (int i = 0; i < vdr; ++i) {
+            const int vi = (v[i] >> (2*i0)) & 0x03030303;
+
+            sumi = ggml_cuda_dp4a(__vsub4(vi, 0x01010101), u[vdr*i0 + i], sumi); // SIMD dot product
+        }
+
+        // TODO: batch subtract by using d8 sum
+        sumf += d8[i0] * sumi;
+    }
+
+    return d2 * sumf;
+}
+
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

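Each ggml_cuda_dp4a call accumulates a 4-way int8 dot product, so per bit-plane i0 the thread computes sumi = sum of ternary bytes times q8_1 bytes, and the result is d2 * sum over i0 of d8[i0]*sumi: the single TQ2_0 scale factors out, which is why one impl serves both MMVQ and MMQ. A plain-C reference of the dp4a step, for checking (a sketch, not the real intrinsic):

static int dp4a_ref(int a, int b, int c) {
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    for (int k = 0; k < 4; ++k) {
        c += a8[k] * b8[k]; // four int8 products accumulated into int32
    }
    return c;
}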
@@ -786,6 +816,37 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
 }

+static __device__ __forceinline__ float vec_dot_tq2_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_tq2_0 * btq2_0 = (const block_tq2_0 *) vbq + kbx;
+
+    // iqs 0..7  all need bq8_offset 0, 1, 2, 3
+    // iqs 8..15 all need bq8_offset 4, 5, 6, 7
+    const int bq8_offset = QR2_0 * (iqs / 8);
+
+    int   v[VDR_TQ2_0_Q8_1_MMVQ];
+    int   u[QR2_0*VDR_TQ2_0_Q8_1_MMVQ];
+    float d8[QR2_0];
+
+#pragma unroll
+    for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_b2(btq2_0->qs, iqs + i);
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR2_0; ++i0) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i0;
+
+        for (int i = 0; i < VDR_TQ2_0_Q8_1_MMVQ; ++i) {
+            u[VDR_TQ2_0_Q8_1_MMVQ*i0 + i] = get_int_b4(bq8i->qs, (iqs % QI8_1) + i);
+        }
+        d8[i0] = __low2float(bq8i->ds);
+    }
+
+    return vec_dot_tq2_0_q8_1_impl<VDR_TQ2_0_Q8_1_MMVQ>(v, u, btq2_0->d, d8);
+}
+
 #define VDR_IQ2_XXS_Q8_1_MMVQ 2
 #define VDR_IQ2_XXS_Q8_1_MMQ  2

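A worked reading of the indexing above (an illustrative interpretation, not from the diff): each block_q8_1 covers 32 values and QI8_1 = 8 ints, so one 256-value TQ2_0 block spans eight q8_1 blocks; iqs/8 picks the low or high 32 bytes of qs, and each of the QR2_0 bit-planes of those bytes lands 32 values further along.

// iqs = 0..7  -> bq8_offset = 0: bit-plane i0 dots against q8_1 block 0 + i0
// iqs = 8..15 -> bq8_offset = 4: bit-plane i0 dots against q8_1 block 4 + i0
// e.g. iqs = 10: v[] holds qs ints 10..11; plane i0 reads bq8_1[4 + i0] at
//      ints (10 % QI8_1) = 2..3, matching output values 128 + 32*i0 + 8..15.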
@@ -3375,7 +3375,8 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
-    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    // GGML_TYPE_TQ1_0,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
     GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
     GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,

@@ -3387,6 +3388,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XXS
 };

@@ -3397,7 +3399,8 @@ static const ggml_type other_types[] = {
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
-    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    // GGML_TYPE_TQ1_0,
+    GGML_TYPE_TQ2_0,
     GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
     GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
     GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,

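With TQ2_0 enabled in the test type lists, the new CUDA kernels are exercised by the backend comparison tests, for example (flag usage assumed from the tool's standard options, not shown in this diff):

./build/bin/test-backend-ops -o MUL_MAT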