Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-10-30 06:30:15 +01:00)
CUDA: mul_mat_vec_q max. batch size 8 -> 4 (#5370)
This commit is contained in:
parent b08f22c882
commit 17c97fb062
ggml-cuda.cu (36 changed lines)
@@ -6831,7 +6831,7 @@ static void mul_mat_vec_q_cuda(
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % qk == 0);
-    GGML_ASSERT(ncols_y <= 8);
+    GGML_ASSERT(ncols_y <= 4);
 
     const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
@@ -6853,22 +6853,22 @@ static void mul_mat_vec_q_cuda(
             mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
             break;
-        case 5:
-            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 6:
-            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 7:
-            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
-        case 8:
-            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
-            break;
+        // case 5:
+        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 6:
+        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 7:
+        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
+        // case 8:
+        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
+        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
+        //     break;
         default:
             GGML_ASSERT(false);
             // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
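For readers skimming the hunk above: ncols_y is the batch size, i.e. the number of columns of the y operand, and it is baked into the kernel as a compile-time template parameter, so the host wrapper has to dispatch through a switch. Commenting out cases 5-8 leaves only the <1>..<4> instantiations, and anything larger now trips the tightened GGML_ASSERT(ncols_y <= 4). Below is a minimal, self-contained sketch of that dispatch pattern; dummy_mat_vec and dispatch_mat_vec are illustrative placeholders, not the actual ggml kernels.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in: a templated kernel that handles a compile-time
// number of batch columns (ncols_y) per launch, one matrix row per thread.
template <int ncols_y>
__global__ void dummy_mat_vec(const float * x, const float * y, float * dst,
                              const int ncols_x, const int nrows_x) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= nrows_x) {
        return;
    }
    float sums[ncols_y] = {0.0f};  // one accumulator per batch column
    for (int i = 0; i < ncols_x; ++i) {
        for (int j = 0; j < ncols_y; ++j) {
            sums[j] += x[row * ncols_x + i] * y[j * ncols_x + i];
        }
    }
    for (int j = 0; j < ncols_y; ++j) {
        dst[j * nrows_x + row] = sums[j];
    }
}

// Host-side dispatch: the runtime batch size selects the template instance,
// mirroring the switch in mul_mat_vec_q_cuda. After this commit only 1..4
// are instantiated; larger batches hit the default branch instead.
static void dispatch_mat_vec(const float * x, const float * y, float * dst,
                             const int ncols_x, const int nrows_x, const int ncols_y,
                             cudaStream_t stream) {
    const dim3 block_dims(32, 1, 1);
    const dim3 block_nums((nrows_x + 31) / 32, 1, 1);
    switch (ncols_y) {
        case 1: dummy_mat_vec<1><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 2: dummy_mat_vec<2><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 3: dummy_mat_vec<3><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        case 4: dummy_mat_vec<4><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_x); break;
        default: fprintf(stderr, "unsupported batch size %d\n", ncols_y); break;
    }
}

Making the batch size a template parameter lets the compiler size the per-column accumulator array and fully unroll those inner loops, which is why each supported size needs its own instantiation.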
@@ -9909,7 +9909,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
+        if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
         } else if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
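The final hunk keeps the caller consistent with the kernel-side limit: ggml_cuda_mul_mat now routes to the batched mul_mat_vec_q path only when src1->ne[1] (the batch size) is at most 4, the device supports DP4A, and src0 is quantized; otherwise it falls back to mul_mat_q or the remaining paths. A rough sketch of that selection, using assumed stand-in types (toy_tensor, select_kernel) rather than the real ggml structures:

// Simplified, assumed stand-ins for the fields and capability checks that
// ggml_cuda_mul_mat reads from the real ggml_tensor and device info.
struct toy_tensor {
    int  ne1;        // number of columns of src1, i.e. the batch size
    bool quantized;  // whether src0 uses a quantized type
};

enum class mul_mat_kernel { vec_q, mat_q, fallback };

// Mirrors the decision changed in this commit: the batched quantized
// matrix-vector kernel is only used up to a batch size of 4.
static mul_mat_kernel select_kernel(const toy_tensor & src0, const toy_tensor & src1,
                                    const int min_compute_capability, const int min_cc_dp4a,
                                    const bool use_mul_mat_q) {
    if (src1.ne1 <= 4 && min_compute_capability >= min_cc_dp4a && src0.quantized) {
        return mul_mat_kernel::vec_q;    // batched mul_mat_vec_q path
    }
    if (use_mul_mat_q) {
        return mul_mat_kernel::mat_q;    // quantized matrix-matrix kernel
    }
    return mul_mat_kernel::fallback;     // remaining paths (dequantize/cuBLAS) not modeled here
}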