CUDA: support for mat. mul. with ne03 != ne13 (#11656)

This commit is contained in:
Johannes Gäßler 2025-02-05 08:58:31 +01:00 committed by GitHub
parent 1ec208083c
commit fa62da9b2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 60 deletions

View File

@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne13 = src1->ne[3]; const int64_t ne13 = src1->ne[3];
const int64_t nrows1 = ggml_nrows(src1); const int64_t nrows1 = ggml_nrows(src1);
GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0]; const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1]; const int64_t ne1 = dst->ne[1];
@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
const int64_t i02_divisor = ne12 / ne02; const int64_t i02_divisor = ne12 / ne02;
const int64_t i03_divisor = ne13 / ne03;
const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type);
@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne02 > 1));
GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12)); GGML_ASSERT(!(split && ne02 < ne12));
GGML_ASSERT(!(split && ne03 < ne13));
ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr; ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
} }
// for split tensors the data begins at i0 == i0_offset_low // for split tensors the data begins at i0 == i0_offset_low
char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(cudaGetLastError()); CUDA_CHECK(cudaGetLastError());
} }
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) { if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
} }
// do the computation // do the computation
@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM // the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_rms_norm_back(ctx, dst); ggml_cuda_op_rms_norm_back(ctx, dst);
break; break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
return false;
} else {
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
}
break; break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst); ggml_cuda_mul_mat_id(ctx, dst);
@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false; return false;
} }
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
return false;
}
#ifdef GGML_USE_MUSA #ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 && if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) { !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@ -1,18 +1,21 @@
#include "ggml.h"
#include "common.cuh" #include "common.cuh"
#include "mmv.cuh" #include "mmv.cuh"
template <typename T, typename type_acc, int block_size> template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec( static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x; const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z; const int64_t channel = blockIdx.y;
const int64_t sample = blockIdx.z;
const int tid = threadIdx.x; const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += (channel/channel_ratio)*stride_channel_x + row*stride_row; x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += channel *stride_channel_y; y += sample *stride_sample_y + channel *stride_channel_y;
dst += channel *stride_channel_dst; dst += sample *stride_sample_dst + channel *stride_channel_dst;
const float2 * y2 = (const float2 *) y; const float2 * y2 = (const float2 *) y;
@ -91,12 +94,15 @@ template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda( static void launch_mul_mat_vec_cuda(
const T * x, const float * y, float * dst, const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0); GGML_ASSERT(nchannels_y % nchannels_x == 0);
GGML_ASSERT(nsamples_y % nsamples_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x; const int64_t channel_ratio = nchannels_y / nchannels_x;
const int64_t sample_ratio = nsamples_y / nsamples_x;
int device; int device;
int warp_size; int warp_size;
@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
} }
const int smem = warp_size*sizeof(float); const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, 1, nchannels_y); const dim3 block_nums(nrows, nchannels_y, nsamples_y);
const dim3 block_dims(block_size_best, 1, 1); const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) { switch (block_size_best) {
case 32: { case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 64: { case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 96: { case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 128: { case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 160: { case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 192: { case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 224: { case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
case 256: { case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>> mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -163,16 +177,19 @@ template<typename T>
static void mul_mat_vec_cuda( static void mul_mat_vec_cuda(
const T * x, const float * y, float * dst, const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) { enum ggml_prec prec, cudaStream_t stream) {
switch (prec) { switch (prec) {
case GGML_PREC_DEFAULT: { case GGML_PREC_DEFAULT: {
launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, launch_mul_mat_vec_cuda<T, half>
stride_channel_x, stride_channel_y, stride_channel_dst, stream); (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break; } break;
case GGML_PREC_F32: { case GGML_PREC_F32: {
launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, launch_mul_mat_vec_cuda<T, float>
stride_channel_x, stride_channel_y, stride_channel_dst, stream); (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break; } break;
} }
} }
@ -181,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; GGML_TENSOR_BINARY_OP_LOCALS;
const int64_t ne01 = src0->ne[1];
GGML_ASSERT(src1->ne[1] == 1); const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne11 == 1);
GGML_ASSERT(ne12 == ne2);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT(nb00 == ts_src0);
GGML_ASSERT(nb10 == ts_src1);
GGML_ASSERT(nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@ -192,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const float * src1_d = (const float *) src1->data; const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data; float * dst_d = (float *) dst->data;
const int64_t ne02 = src0->ne[2]; const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t ne12 = src1->ne[2]; const int64_t s02 = src0->nb[2] / ts_src0;
GGML_ASSERT(dst->ne[2] == ne12); const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
GGML_ASSERT(src0->ne[3] == 1); const int64_t s03 = src0->nb[3] / ts_src0;
GGML_ASSERT(src1->ne[3] == 1); const int64_t s13 = src1->nb[3] / ts_src1;
GGML_ASSERT( dst->ne[3] == 1); const int64_t s3 = dst->nb[3] / ts_dst;
const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data; const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
const int64_t stride_row = ne00; const int64_t stride_row = ne00;
const int64_t nchannels_x = 1; const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1; const int64_t nchannels_y = 1;
const int64_t channel_stride_x = 0; const int64_t stride_channel_x = 0;
const int64_t channel_stride_y = 0; const int64_t stride_channel_y = 0;
const int64_t channel_stride_dst = 0; const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_y = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i; const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));