mirror of https://github.com/ggerganov/llama.cpp.git
ggml_cuda_cpy: support for 4d tensors and float16->float32 upcasting (ggml/686)

* added CUDA float16->float32 upcasting to ggml_cuda_cpy
* added the ability to copy 4d tensors with the CUDA backend
* added tests for the float16->float32 upcast and for 4d tensor CUDA copies
* added a 4d copy test for the float32->float16 copy
* applied the patch suggested by @iamlemec
* simplified the cpy tests

Co-authored-by: slaren <slarengh@gmail.com>
parent a4b07c057a
commit 625a699b54
ggml-cuda.cu (131 changed lines)
@@ -5511,27 +5511,37 @@ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                   const int nb12, const int nb13) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = i - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
 
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = i - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
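The new index arithmetic is a mixed-radix decomposition: the flattened element index i is split into per-dimension coordinates (i03, i02, i01, i00), which are then combined with the per-dimension byte strides. A minimal host-side sketch (not part of the patch; all names illustrative) that checks the decomposition against a contiguous float32 layout:

#include <cassert>

int main() {
    // element counts per dimension and contiguous byte strides for float32
    const int ne00 = 5, ne01 = 4, ne02 = 3, ne03 = 2;
    const int nb00 = 4;            // sizeof(float)
    const int nb01 = nb00*ne00;
    const int nb02 = nb01*ne01;
    const int nb03 = nb02*ne02;

    const int ne = ne00*ne01*ne02*ne03;
    for (int i = 0; i < ne; ++i) {
        // same decomposition as cpy_f32_f16 above
        const int i03 = i/(ne00*ne01*ne02);
        const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
        const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
        const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
        const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
        // for a contiguous tensor the byte offset must equal i times the element size
        assert(x_offset == i*nb00);
    }
    return 0;
}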
@@ -5625,23 +5635,26 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
 
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
 
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
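cpy_f32_q differs from the elementwise kernel in that each thread handles one quantization block of qk source values, so the destination is addressed in whole blocks: i10 is divided by qk before being scaled by nb10, which for quantized types is the size of one block in bytes. A small sketch of that offset calculation for Q8_0, using a simplified stand-in for ggml's block layout:

#include <cstdint>
#include <cstdio>

// simplified stand-in for ggml's block_q8_0: one fp16 scale + 32 int8 quants (34 bytes)
struct block_q8_0 {
    uint16_t d;
    int8_t   qs[32];
};

int main() {
    const int qk   = 32;
    const int nb10 = sizeof(block_q8_0); // byte stride of one block in the destination
    for (int i10 = 0; i10 < 128; i10 += qk) {
        // as in cpy_f32_q: whole blocks, not elements, index the quantized destination
        const int dst_offset = (i10/qk)*nb10;
        printf("elements [%3d, %3d) -> block %d, byte offset %d\n",
               i10, i10 + qk, i10/qk, dst_offset);
    }
    return 0;
}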
@@ -7308,69 +7321,82 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
         (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 
+static void ggml_cpy_f16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 
 
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
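The wrappers only compute the launch geometry and forward the strides; the element conversion for the new path is the one-line half-to-float assignment in cpy_1_f16_f32. A self-contained CUDA sketch of the same upcast for the simple contiguous case (illustrative only; the ggml kernel above additionally handles arbitrary 4d strides, and block_size here merely stands in for CUDA_CPY_BLOCK_SIZE):

#include <cuda_fp16.h>

__global__ void upcast_f16_f32(const half * x, float * dst, const int ne) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    dst[i] = __half2float(x[i]); // the upcast: read half, store float
}

int main() {
    const int ne = 4096;
    half  * dx;
    float * ddst;
    cudaMalloc(&dx,   ne*sizeof(half));
    cudaMalloc(&ddst, ne*sizeof(float));

    // same ceil-division grid sizing as the ggml_cpy_*_cuda wrappers
    const int block_size = 64;
    const int num_blocks = (ne + block_size - 1) / block_size;
    upcast_f16_f32<<<num_blocks, block_size>>>(dx, ddst, ne);
    cudaDeviceSynchronize();

    cudaFree(dx);
    cudaFree(ddst);
    return 0;
}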
@@ -10119,19 +10145,25 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    GGML_ASSERT(src0->ne[3] == 1);
+    const int64_t ne02 = src0->ne[2];
+
+    //GGML_ASSERT(src0->ne[3] == 1);
 
     const int64_t nb00 = src0->nb[0];
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
+    const int64_t nb03 = src0->nb[3];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(src1->ne[3] == 1);
+    const int64_t ne12 = src1->ne[2];
+
+    //GGML_ASSERT(src1->ne[3] == 1);
 
     const int64_t nb10 = src1->nb[0];
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
 
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
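For reference, ggml stores a tensor's shape in ne[] (elements per dimension) and its strides in nb[] (bytes per step along each dimension), so dropping the ne[3] == 1 assertions and forwarding ne[2] and nb[3] is all the host side needs for 4d copies. A short sketch of the contiguous-stride convention (illustrative; for quantized types nb[0] is the block size in bytes rather than the element size):

#include <cstdint>
#include <cstdio>

int main() {
    // ne[d] = elements in dimension d; nb[d] = byte stride of dimension d
    const int64_t ne[4] = {256, 4, 4, 4}; // the shape used by the new tests
    int64_t nb[4];
    nb[0] = sizeof(float);                // element size for float32
    for (int d = 1; d < 4; ++d) {
        nb[d] = nb[d-1]*ne[d-1];          // contiguous layout
    }
    printf("nb = {%lld, %lld, %lld, %lld}\n",
           (long long) nb[0], (long long) nb[1], (long long) nb[2], (long long) nb[3]);
    return 0;
}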
@@ -10143,17 +10175,19 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -11156,6 +11190,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 return false;
             } break;
         case GGML_OP_DUP:
tests/test-backend-ops.cpp
@@ -1927,8 +1927,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
 
-    for (ggml_type type : all_types) {
-        test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
+    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (ggml_type type_dst : all_types) {
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+        }
     }
 
     test_cases.emplace_back(new test_cont());
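The reworked loop exercises both F16 and F32 sources against every destination type, on a shape that is non-trivial in all four dimensions. At the user level, the new capability amounts to a graph like the following hedged sketch, which uses the public ggml API (backend setup and error checking omitted; CPU compute shown, but the same graph now also runs on the CUDA backend):

#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4d f16 -> f32 copy; the CUDA backend previously asserted ne[3] == 1
    // and had no f16 -> f32 copy kernel at all
    struct ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 256, 4, 4, 4);
    struct ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 4, 4, 4);
    struct ggml_tensor * out = ggml_cpy(ctx, src, dst);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}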