mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 13:58:46 +01:00
Quantized dot products for CUDA mul mat vec (#2067)
This commit is contained in:
parent
051c70dcd5
commit
924dd22fd3
@ -68,8 +68,9 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework
|
|||||||
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
||||||
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
|
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
|
||||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
||||||
|
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||||
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
||||||
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
|
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
|
||||||
option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
|
option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
|
||||||
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
|
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
|
||||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
||||||
@ -246,8 +247,14 @@ if (LLAMA_CUBLAS)
|
|||||||
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_CUBLAS)
|
add_compile_definitions(GGML_USE_CUBLAS)
|
||||||
|
if (LLAMA_CUDA_FORCE_DMMV)
|
||||||
|
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
|
||||||
|
endif()
|
||||||
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
||||||
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
|
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
|
||||||
|
if (DEFINED LLAMA_CUDA_DMMV_Y)
|
||||||
|
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility
|
||||||
|
endif()
|
||||||
if (LLAMA_CUDA_DMMV_F16)
|
if (LLAMA_CUDA_DMMV_F16)
|
||||||
add_compile_definitions(GGML_CUDA_DMMV_F16)
|
add_compile_definitions(GGML_CUDA_DMMV_F16)
|
||||||
endif()
|
endif()
|
||||||
@ -263,7 +270,7 @@ if (LLAMA_CUBLAS)
|
|||||||
if (LLAMA_CUDA_DMMV_F16)
|
if (LLAMA_CUDA_DMMV_F16)
|
||||||
set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
|
set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CUDA_ARCHITECTURES "52") # lowest CUDA 12 standard
|
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||||
|
13
Makefile
13
Makefile
@ -164,16 +164,21 @@ ifdef LLAMA_CUBLAS
|
|||||||
OBJS += ggml-cuda.o
|
OBJS += ggml-cuda.o
|
||||||
NVCC = nvcc
|
NVCC = nvcc
|
||||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||||
|
ifdef LLAMA_CUDA_FORCE_DMMV
|
||||||
|
NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV
|
||||||
|
endif # LLAMA_CUDA_FORCE_DMMV
|
||||||
ifdef LLAMA_CUDA_DMMV_X
|
ifdef LLAMA_CUDA_DMMV_X
|
||||||
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||||
else
|
else
|
||||||
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
|
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
|
||||||
endif # LLAMA_CUDA_DMMV_X
|
endif # LLAMA_CUDA_DMMV_X
|
||||||
ifdef LLAMA_CUDA_DMMV_Y
|
ifdef LLAMA_CUDA_MMV_Y
|
||||||
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
|
NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
||||||
|
else ifdef LLAMA_CUDA_DMMV_Y
|
||||||
|
NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_DMMV_Y) # for backwards compatibility
|
||||||
else
|
else
|
||||||
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
|
NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
|
||||||
endif # LLAMA_CUDA_DMMV_Y
|
endif # LLAMA_CUDA_MMV_Y
|
||||||
ifdef LLAMA_CUDA_DMMV_F16
|
ifdef LLAMA_CUDA_DMMV_F16
|
||||||
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
|
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
|
||||||
endif # LLAMA_CUDA_DMMV_F16
|
endif # LLAMA_CUDA_DMMV_F16
|
||||||
|
@ -345,8 +345,9 @@ Building the program with BLAS support may lead to some performance improvements
|
|||||||
|
|
||||||
| Option | Legal values | Default | Description |
|
| Option | Legal values | Default | Description |
|
||||||
|-------------------------|------------------------|---------|-------------|
|
|-------------------------|------------------------|---------|-------------|
|
||||||
|
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 7.0/Turing/RTX 2000 or higher). Does not affect k-quants. |
|
||||||
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
|
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
|
||||||
| LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
|
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
|
||||||
| LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
|
| LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
|
||||||
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
|
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
|
||||||
|
|
||||||
|
379
ggml-cuda.cu
379
ggml-cuda.cu
@ -70,9 +70,11 @@ typedef void (*ggml_cuda_op_t)(
|
|||||||
|
|
||||||
// QK = number of values after dequantization
|
// QK = number of values after dequantization
|
||||||
// QR = QK / number of values before dequantization
|
// QR = QK / number of values before dequantization
|
||||||
|
// QI = number of 32 bit integers before dequantization
|
||||||
|
|
||||||
#define QK4_0 32
|
#define QK4_0 32
|
||||||
#define QR4_0 2
|
#define QR4_0 2
|
||||||
|
#define QI4_0 4
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||||
@ -81,6 +83,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
|
|||||||
|
|
||||||
#define QK4_1 32
|
#define QK4_1 32
|
||||||
#define QR4_1 2
|
#define QR4_1 2
|
||||||
|
#define QI4_1 4
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
half m; // min
|
half m; // min
|
||||||
@ -90,6 +93,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
|
|||||||
|
|
||||||
#define QK5_0 32
|
#define QK5_0 32
|
||||||
#define QR5_0 2
|
#define QR5_0 2
|
||||||
|
#define QI5_0 4
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
uint8_t qh[4]; // 5-th bit of quants
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
@ -99,6 +103,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
|
|||||||
|
|
||||||
#define QK5_1 32
|
#define QK5_1 32
|
||||||
#define QR5_1 2
|
#define QR5_1 2
|
||||||
|
#define QI5_1 4
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
half m; // min
|
half m; // min
|
||||||
@ -109,12 +114,25 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
|
|||||||
|
|
||||||
#define QK8_0 32
|
#define QK8_0 32
|
||||||
#define QR8_0 1
|
#define QR8_0 1
|
||||||
|
#define QI8_0 8
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
int8_t qs[QK8_0]; // quants
|
int8_t qs[QK8_0]; // quants
|
||||||
} block_q8_0;
|
} block_q8_0;
|
||||||
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
|
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
|
||||||
|
|
||||||
|
#define QK8_1 32
|
||||||
|
#define QR8_1 1
|
||||||
|
#define QI8_1 8
|
||||||
|
typedef struct {
|
||||||
|
half d; // delta
|
||||||
|
half s; // unquantized sum
|
||||||
|
int8_t qs[QK8_0]; // quants
|
||||||
|
} block_q8_1;
|
||||||
|
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
|
||||||
|
|
||||||
|
typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
|
||||||
|
|
||||||
//================================= k-quants
|
//================================= k-quants
|
||||||
|
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
@ -198,14 +216,15 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|||||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||||
#define CUDA_ROPE_BLOCK_SIZE 256
|
#define CUDA_ROPE_BLOCK_SIZE 256
|
||||||
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
|
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
|
||||||
|
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
|
|
||||||
// dmmv = dequantize_mul_mat_vec
|
// dmmv = dequantize_mul_mat_vec
|
||||||
#ifndef GGML_CUDA_DMMV_X
|
#ifndef GGML_CUDA_DMMV_X
|
||||||
#define GGML_CUDA_DMMV_X 32
|
#define GGML_CUDA_DMMV_X 32
|
||||||
#endif
|
#endif
|
||||||
#ifndef GGML_CUDA_DMMV_Y
|
#ifndef GGML_CUDA_MMV_Y
|
||||||
#define GGML_CUDA_DMMV_Y 1
|
#define GGML_CUDA_MMV_Y 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef K_QUANTS_PER_ITERATION
|
#ifndef K_QUANTS_PER_ITERATION
|
||||||
@ -270,7 +289,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums
|
// sum up partial sums
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -714,7 +732,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -819,7 +836,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -923,7 +939,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1028,7 +1043,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1139,7 +1153,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1158,6 +1171,41 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
|
|||||||
v.y = x[ib + iqs + 1];
|
v.y = x[ib + iqs + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
|
||||||
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
block_q8_1 * y = (block_q8_1 *) vy;
|
||||||
|
|
||||||
|
const int ib = i / QK8_0; // block index
|
||||||
|
const int iqs = i % QK8_0; // quant index
|
||||||
|
|
||||||
|
const float xi = x[i];
|
||||||
|
float amax = fabsf(xi);
|
||||||
|
float sum = xi;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
|
||||||
|
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = amax / 127;
|
||||||
|
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||||
|
|
||||||
|
y[ib].qs[iqs] = q;
|
||||||
|
|
||||||
|
if (iqs > 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
y[ib].d = d;
|
||||||
|
y[ib].s = sum;
|
||||||
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
|
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
||||||
@ -1179,6 +1227,182 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
|||||||
y[iybs + iqs + y_offset] = v.y;
|
y[iybs + iqs + y_offset] = v.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||||
|
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||||
|
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
|
||||||
|
|
||||||
|
int vi;
|
||||||
|
memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
|
||||||
|
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
|
||||||
|
|
||||||
|
const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
|
||||||
|
|
||||||
|
// subtract 8 from each quantized value
|
||||||
|
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
|
||||||
|
const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
|
||||||
|
|
||||||
|
// SIMD dot product of quantized values
|
||||||
|
int sumi = __dp4a(vi0, ui0, 0);
|
||||||
|
sumi = __dp4a(vi1, ui1, sumi);
|
||||||
|
|
||||||
|
return sumi*d;
|
||||||
|
#else
|
||||||
|
return 0.0f; // only to satisfy the compiler
|
||||||
|
#endif // __CUDA_ARCH__ >= 600
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||||
|
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||||
|
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
|
||||||
|
|
||||||
|
const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
|
||||||
|
|
||||||
|
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
|
||||||
|
const float m = bq4_1->m;
|
||||||
|
const float s = bq8_1->s;
|
||||||
|
|
||||||
|
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
|
||||||
|
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
|
||||||
|
|
||||||
|
// SIMD dot product of quantized values
|
||||||
|
int sumi = __dp4a(vi0, ui0, 0);
|
||||||
|
sumi = __dp4a(vi1, ui1, sumi);
|
||||||
|
|
||||||
|
return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
|
||||||
|
#else
|
||||||
|
return 0.0f; // only to satisfy the compiler
|
||||||
|
#endif // __CUDA_ARCH__ >= 600
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||||
|
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||||
|
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
|
||||||
|
|
||||||
|
int qs;
|
||||||
|
memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
|
||||||
|
const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
|
||||||
|
const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
|
||||||
|
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
|
||||||
|
|
||||||
|
const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
|
||||||
|
|
||||||
|
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
|
||||||
|
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
|
||||||
|
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
|
||||||
|
vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
|
||||||
|
vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
|
||||||
|
vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
|
||||||
|
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
|
||||||
|
|
||||||
|
int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
|
||||||
|
vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
|
||||||
|
vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
|
||||||
|
vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
|
||||||
|
vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
|
||||||
|
vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
|
||||||
|
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
|
||||||
|
|
||||||
|
return sumi*d;
|
||||||
|
#else
|
||||||
|
return 0.0f; // only to satisfy the compiler
|
||||||
|
#endif // __CUDA_ARCH__ >= 600
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||||
|
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||||
|
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
|
||||||
|
|
||||||
|
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
|
||||||
|
const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
|
||||||
|
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
|
||||||
|
|
||||||
|
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
|
||||||
|
const float m = bq5_1->m;
|
||||||
|
const float s = bq8_1->s;
|
||||||
|
|
||||||
|
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
|
||||||
|
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
|
||||||
|
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
|
||||||
|
vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
|
||||||
|
vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
|
||||||
|
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
|
||||||
|
|
||||||
|
int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
|
||||||
|
vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
|
||||||
|
vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
|
||||||
|
vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
|
||||||
|
vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
|
||||||
|
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
|
||||||
|
|
||||||
|
return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
|
||||||
|
#else
|
||||||
|
return 0.0f; // only to satisfy the compiler
|
||||||
|
#endif // __CUDA_ARCH__ >= 600
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
|
||||||
|
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
|
||||||
|
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
|
||||||
|
|
||||||
|
int vi;
|
||||||
|
memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
|
||||||
|
const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
|
||||||
|
|
||||||
|
const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
|
||||||
|
|
||||||
|
// SIMD dot product of quantized values
|
||||||
|
int sumi = __dp4a(vi, ui, 0);
|
||||||
|
|
||||||
|
return sumi*d;
|
||||||
|
#else
|
||||||
|
return 0.0f; // only to satisfy the compiler
|
||||||
|
#endif // __CUDA_ARCH__ >= 600
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||||
|
static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
|
||||||
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
|
if (row >= nrows) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_row = ncols / qk;
|
||||||
|
const int blocks_per_warp = WARP_SIZE / qi;
|
||||||
|
|
||||||
|
// partial sum for each thread
|
||||||
|
float tmp = 0.0f;
|
||||||
|
|
||||||
|
const block_q_t * x = (const block_q_t *) vx;
|
||||||
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||||
|
|
||||||
|
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
|
||||||
|
const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
|
||||||
|
|
||||||
|
const int iby = i + threadIdx.x / qi; // y block index
|
||||||
|
|
||||||
|
const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int
|
||||||
|
|
||||||
|
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// sum up partial sums and write back result
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
dst[row] = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
|
||||||
// qk = quantized weights per x block
|
// qk = quantized weights per x block
|
||||||
@ -1233,7 +1457,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1284,7 +1507,6 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
|
|||||||
const int idst = channel*nrows_dst + row_dst;
|
const int idst = channel*nrows_dst + row_dst;
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1330,7 +1552,6 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1440,7 +1661,6 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums
|
// sum up partial sums
|
||||||
__syncthreads();
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
@ -1494,6 +1714,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
|
|||||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
|
||||||
|
const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||||
|
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||||
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
@ -1562,45 +1787,45 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
|
|||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
@ -1647,6 +1872,51 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
|
|||||||
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
|
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
|
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
|
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
|
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
|
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, vec_dot_q8_0_q8_1>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||||
|
}
|
||||||
|
|
||||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||||
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
@ -1654,9 +1924,9 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
|||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
@ -1822,6 +2092,7 @@ static size_t g_scratch_offset = 0;
|
|||||||
|
|
||||||
static int g_device_count = -1;
|
static int g_device_count = -1;
|
||||||
static int g_main_device = 0;
|
static int g_main_device = 0;
|
||||||
|
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
|
||||||
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||||
|
|
||||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||||
@ -1839,9 +2110,12 @@ void ggml_init_cublas() {
|
|||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
cudaDeviceProp prop;
|
cudaDeviceProp prop;
|
||||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||||
fprintf(stderr, " Device %d: %s\n", id, prop.name);
|
fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
|
||||||
|
|
||||||
g_tensor_split[id] = total_vram;
|
g_tensor_split[id] = total_vram;
|
||||||
total_vram += prop.totalGlobalMem;
|
total_vram += prop.totalGlobalMem;
|
||||||
|
|
||||||
|
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
|
||||||
}
|
}
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
g_tensor_split[id] /= total_vram;
|
g_tensor_split[id] /= total_vram;
|
||||||
@ -2057,7 +2331,7 @@ inline void ggml_cuda_op_rms_norm(
|
|||||||
(void) i1;
|
(void) i1;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
inline void ggml_cuda_op_mul_mat_vec(
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
|
||||||
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
|
||||||
cudaStream_t & cudaStream_main){
|
cudaStream_t & cudaStream_main){
|
||||||
@ -2069,7 +2343,53 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
|||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows = i01_high - i01_low;
|
const int64_t nrows = i01_high - i01_low;
|
||||||
|
|
||||||
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
#ifdef GGML_CUDA_FORCE_DMMV
|
||||||
|
const bool use_mul_mat_vec_q = false;
|
||||||
|
#else
|
||||||
|
int id;
|
||||||
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
|
|
||||||
|
const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
|
||||||
|
src0->type == GGML_TYPE_Q4_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q5_0 ||
|
||||||
|
src0->type == GGML_TYPE_Q5_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q8_0;
|
||||||
|
|
||||||
|
// The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
|
||||||
|
// However, they have bad performance with Pascal cards.
|
||||||
|
// Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
|
||||||
|
const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (use_mul_mat_vec_q) {
|
||||||
|
size_t as;
|
||||||
|
void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
|
||||||
|
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cuda_pool_free(src1_q8_1, as);
|
||||||
|
} else {
|
||||||
|
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
||||||
#ifdef GGML_CUDA_DMMV_F16
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
size_t ash;
|
size_t ash;
|
||||||
dfloat * src1_dfloat = nullptr; // dfloat == half
|
dfloat * src1_dfloat = nullptr; // dfloat == half
|
||||||
@ -2132,6 +2452,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
|||||||
ggml_cuda_pool_free(src1_dfloat, ash);
|
ggml_cuda_pool_free(src1_dfloat, ash);
|
||||||
}
|
}
|
||||||
#endif // GGML_CUDA_DMMV_F16
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
}
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
@ -2701,8 +3022,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
|
|||||||
}else if (src0->type == GGML_TYPE_F32) {
|
}else if (src0->type == GGML_TYPE_F32) {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
||||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
|
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user