mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 21:37:19 +01:00
CUDA: mul_mat_vec_q tiling, refactor mul mat logic (#5434)
* CUDA: mul_mat_vec_q tiling, refactor mul mat logic Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
2891c8aa9a
commit
3bdc4cd0f5
223
ggml-cuda.cu
223
ggml-cuda.cu
@ -150,8 +150,8 @@
|
|||||||
#define CUDA_USE_TENSOR_CORES
|
#define CUDA_USE_TENSOR_CORES
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// max batch size to use MMQ kernels when tensor cores are available
|
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
|
||||||
#define MMQ_MAX_BATCH_SIZE 32
|
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#define __CUDA_ARCH__ 1300
|
#define __CUDA_ARCH__ 1300
|
||||||
@ -5310,51 +5310,59 @@ template <bool need_check> static __global__ void
|
|||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
#define MMVQ_NWARPS_NVIDIA 4
|
template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||||
#define MMVQ_NWARPS_AMD_RDNA2 1
|
|
||||||
#define MMVQ_NWARPS_AMD_OLD 4
|
|
||||||
|
|
||||||
template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
|
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||||
|
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
static __global__ void mul_mat_vec_q(
|
static __global__ void mul_mat_vec_q(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||||
|
constexpr int nwarps = 1;
|
||||||
|
constexpr int rows_per_cuda_block = 1;
|
||||||
|
#else
|
||||||
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||||
|
|
||||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
const int row = blockIdx.x;
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
|
||||||
const int blocks_per_row_x = ncols_x / qk;
|
const int blocks_per_row_x = ncols_x / qk;
|
||||||
const int blocks_per_col_y = nrows_y / QK8_1;
|
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||||
const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
||||||
|
|
||||||
// partial sum for each thread
|
// partial sum for each thread
|
||||||
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
|
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
||||||
|
|
||||||
const block_q_t * x = (const block_q_t *) vx;
|
const block_q_t * x = (const block_q_t *) vx;
|
||||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||||
|
|
||||||
for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
|
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
||||||
const int ibx = row*blocks_per_row_x + i; // x block index
|
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
|
||||||
|
|
||||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
// x block quant index when casting the quants to int
|
||||||
|
const int kqs = vdr * (tid % (qi/vdr));
|
||||||
const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
|
#pragma unroll
|
||||||
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
|
tmp[j][i] += vec_dot_q_cuda(
|
||||||
|
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
|
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
|
||||||
if (threadIdx.y > 0) {
|
if (threadIdx.y > 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
|
#pragma unroll
|
||||||
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
|
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < nwarps-1; ++i) {
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
tmp[j] += tmp_shared[i][j][threadIdx.x];
|
#pragma unroll
|
||||||
|
for (int l = 0; l < nwarps-1; ++l) {
|
||||||
|
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
||||||
|
}
|
||||||
|
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
|
||||||
}
|
}
|
||||||
tmp[j] = warp_reduce_sum(tmp[j]);
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x < rows_per_cuda_block) {
|
||||||
dst[j*nrows_dst + row] = tmp[j];
|
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda(
|
|||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(ncols_x % qk == 0);
|
GGML_ASSERT(ncols_x % qk == 0);
|
||||||
GGML_ASSERT(ncols_y <= 4);
|
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
|
||||||
|
|
||||||
int id;
|
int id;
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
|
|
||||||
int nwarps;
|
int64_t nwarps = 1;
|
||||||
if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
|
int64_t rows_per_cuda_block = 1;
|
||||||
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
|
|
||||||
} else {
|
|
||||||
nwarps = MMVQ_NWARPS_NVIDIA;
|
|
||||||
}
|
|
||||||
|
|
||||||
const dim3 block_nums(nrows_x, 1, 1);
|
if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||||
|
switch(ncols_y) {
|
||||||
|
case 1:
|
||||||
|
nwarps = 4;
|
||||||
|
rows_per_cuda_block = 1;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
case 3:
|
||||||
|
case 4:
|
||||||
|
nwarps = 4;
|
||||||
|
rows_per_cuda_block = 2;
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
case 6:
|
||||||
|
case 7:
|
||||||
|
case 8:
|
||||||
|
nwarps = 2;
|
||||||
|
rows_per_cuda_block = 2;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||||
|
const dim3 block_nums(nblocks, 1, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||||
|
|
||||||
switch (nwarps) {
|
switch (ncols_y) {
|
||||||
case 1: switch(ncols_y) {
|
|
||||||
case 1:
|
case 1:
|
||||||
mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
default:
|
case 5:
|
||||||
GGML_ASSERT(false);
|
mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
} break;
|
case 6:
|
||||||
case 4: switch(ncols_y) {
|
mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
case 1:
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 7:
|
||||||
mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 8:
|
||||||
mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
|
mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||||
break;
|
break;
|
||||||
case 4:
|
|
||||||
mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
break;
|
|
||||||
} break;
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
break;
|
break;
|
||||||
@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs(
|
|||||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||||
|
|
||||||
@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
|
|
||||||
int64_t min_compute_capability = INT_MAX;
|
int64_t min_compute_capability = INT_MAX;
|
||||||
|
|
||||||
|
bool any_pascal_with_slow_fp16 = false;
|
||||||
if (split) {
|
if (split) {
|
||||||
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
|
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
|
||||||
auto & tensor_split = buft_ctx->tensor_split;
|
auto & tensor_split = buft_ctx->tensor_split;
|
||||||
for (int id = 0; id < g_device_count; ++id) {
|
for (int id = 0; id < g_device_count; ++id) {
|
||||||
if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
|
// skip devices that are not going to do any work:
|
||||||
|
if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (min_compute_capability > g_device_caps[id].cc) {
|
||||||
min_compute_capability = g_device_caps[id].cc;
|
min_compute_capability = g_device_caps[id].cc;
|
||||||
}
|
}
|
||||||
|
if (g_device_caps[id].cc == 610) {
|
||||||
|
any_pascal_with_slow_fp16 = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
min_compute_capability = g_device_caps[g_main_device].cc;
|
min_compute_capability = g_device_caps[g_main_device].cc;
|
||||||
|
any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||||
|
bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
|
||||||
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
|
&& src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
|
||||||
|
|
||||||
|
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
|
||||||
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||||
|
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||||
|
|
||||||
|
bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
|
||||||
|
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
|
const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
|
||||||
bool use_mul_mat_q = ggml_is_quantized(src0->type);
|
|
||||||
#ifdef CUDA_USE_TENSOR_CORES
|
#ifdef CUDA_USE_TENSOR_CORES
|
||||||
use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
|
use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
|
||||||
#endif // CUDA_USE_TENSOR_CORES
|
#endif // CUDA_USE_TENSOR_CORES
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
|
// fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
|
||||||
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
|
||||||
|
|
||||||
|
// mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
|
||||||
|
use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
|
||||||
|
use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A;
|
||||||
|
|
||||||
#ifdef CUDA_USE_TENSOR_CORES
|
#ifdef CUDA_USE_TENSOR_CORES
|
||||||
// when tensor cores are available, use them for large batch size
|
// when tensor cores are available, use them for large batch size
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
||||||
use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
|
use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||||
#endif // CUDA_USE_TENSOR_CORES
|
#endif // CUDA_USE_TENSOR_CORES
|
||||||
|
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
|
// if mmvq is available it's a better choice than dmmv:
|
||||||
|
#ifndef GGML_CUDA_FORCE_DMMV
|
||||||
|
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||||
|
#endif // GGML_CUDA_FORCE_DMMV
|
||||||
|
|
||||||
// debug helpers
|
// debug helpers
|
||||||
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
||||||
@ -9943,24 +9994,10 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||||
} else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
} else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||||
// KQ + KQV multi-batch
|
// KQ + KQV multi-batch
|
||||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
|
||||||
} else if (src0->type == GGML_TYPE_F32) {
|
} else if (use_dequantize_mul_mat_vec) {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
|
||||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
|
||||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
|
|
||||||
#ifdef GGML_CUDA_FORCE_DMMV
|
|
||||||
const bool use_mul_mat_vec_q = false;
|
|
||||||
#else
|
|
||||||
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
|
||||||
#endif // GGML_CUDA_FORCE_DMMV
|
|
||||||
|
|
||||||
if (use_mul_mat_vec_q) {
|
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
|
||||||
} else {
|
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
||||||
}
|
} else if (use_mul_mat_vec_q) {
|
||||||
} else {
|
|
||||||
if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
|
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
||||||
} else if (use_mul_mat_q) {
|
} else if (use_mul_mat_q) {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
||||||
@ -9968,10 +10005,6 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
template<typename ... Srcs>
|
template<typename ... Srcs>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user