mirror of https://github.com/ggerganov/llama.cpp.git
llava : add MobileVLM support (#5132)
* New features:
    1. Sum_Rows: fix CUDA kernel overflow; fix block shape error when nrows is too big
    2. Im2Col: support batch in CUDA; support f32 to f32 on both CPU and CUDA
    3. DepthWiseConv: supported via Im2Col and MulMat
    4. Pool_2d: support avg pooling in CUDA
    5. HardSigmoid: implemented in CUDA
    6. HardSwish: implemented in CUDA
* fix tabs instead of spaces
* code clean
* CUDA POOL2D
* ADD POOL2D test case in test-backend-ops.cpp
* code clean
* fix pool2d_kernel nits
* fix bug in pool2d kernel
* fix avg pooling, count_include_pad nits
* test-backend-ops : add more pool_2d tests
* cuda : fix warnings and formatting
* ggml : check types in release builds too in pool_2d
* test-backend-ops : remove f16 pool_2d tests
* cuda : more style fixes
* Add assert in ggml_cuda_op_pool2d
* pool2d float padding fallback
* test-backend-ops : add dst_type to im2col

---------

Co-authored-by: slaren <slarengh@gmail.com>
commit 15606309a0 (parent b2b9f025e7)
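
For reference, the two new unary ops compute the standard piecewise-linear approximations (this matches the CUDA kernels added in the diff below):

$$\mathrm{hardsigmoid}(x) = \min\!\left(1,\ \max\!\left(0,\ \tfrac{x+3}{6}\right)\right), \qquad \mathrm{hardswish}(x) = x \cdot \mathrm{hardsigmoid}(x)$$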
@@ -111,17 +111,71 @@ llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 m
llama_print_timings: total time = 34570.79 ms
```
## Orin compile and run

### compile

```sh
make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_87 LLAMA_CUDA_F16=1 -j 32
```
### run on Orin

### case 1

**input**

```sh
./llava-cli \
    -m /data/local/tmp/ggml-model-q4_k.gguf \
    --mmproj /data/local/tmp/mmproj-model-f16.gguf \
    --image /data/local/tmp/demo.jpeg \
    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \
    --n-gpu-layers 999
```

**output**

```sh
encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch)

 Susan Wise Bauer

llama_print_timings: load time = 1067.64 ms
llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second)
llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second)
llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second)
llama_print_timings: total time = 1352.63 ms / 252 tokens
```

### case 2

**input**

```sh
./llava-cli \
    -m /data/local/tmp/ggml-model-q4_k.gguf \
    --mmproj /data/local/tmp/mmproj-model-f16.gguf \
    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:" \
    --n-gpu-layers 999
```

**output**

```sh
encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch)

 The image features a cat lying in the grass.

llama_print_timings: load time = 1057.07 ms
llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second)
llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second)
llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second)
llama_print_timings: total time = 1365.47 ms / 243 tokens
```
## Minor shortcomings

The `n_patch` of the output in `ldp` is 1/4 of the input. As a quick implementation, we uniformly modified the `clip_n_patches` function to return a quarter of the patches. As a result, when measuring time consumption, the reported per-patch time is 4 times larger than the real cost.
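
A minimal sketch of what that uniform quartering amounts to (illustrative only; the actual `clip_n_patches` lives in `clip.cpp` and may differ in detail):

```c
// Illustrative sketch, not the upstream implementation: with the LDP
// projector the effective patch count is 1/4 of the ViT patch grid, so a
// uniform division by 4 makes per-patch timings read 4x larger than reality.
static int clip_n_patches_sketch(int image_size, int patch_size) {
    int n_patches = (image_size / patch_size) * (image_size / patch_size);
    return n_patches / 4; // the uniform quartering described above
}
```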
## TODO

- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid`
- [ ] Optimize LDP projector performance
    - Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`;
    - Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc.
- [x] run MobileVLM on `Jetson Orin`
- [ ] Support more model variants, such as `MobileVLM-3B`.
ggml-cuda.cu (209 changes)
```diff
@@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
 #define CUDA_PAD_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
+#define CUDA_POOL2D_BLOCK_SIZE 256

 #define CUDA_Q8_0_NE_ALIGN 2048

```
```diff
@@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }

+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
 static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
```
```diff
@@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }

 static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.y;
+    const int row = blockIdx.x;
     const int col = threadIdx.x;

     float sum = 0.0f;
```
```diff
@@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }

-static __global__ void im2col_f32_f16(
-        const float * x, half * dst,
-        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
+template <typename T>
+static __global__ void im2col_kernel(
+        const float * x, T * dst, int batch_offset,
+        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
@@ -6160,21 +6182,73 @@ static __global__ void im2col_f32_f16(
     const int ky = (i - kd) / OW;
     const int ix = i % OW;

+    const int oh = blockIdx.y;
+    const int batch = blockIdx.z / IC;
+    const int ic = blockIdx.z % IC;
+
     const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
+    const int64_t iih = oh * s1 + ky * d1 - p1;

     const int64_t offset_dst =
-        (blockIdx.y * OW + ix) * CHW +
-        (blockIdx.z * (KW * KH) + ky * KW + kx);
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);

     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = __float2half(0.0f);
+        dst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = blockIdx.z * offset_delta;
-        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
     }
 }

```
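As an aside, the `OH`/`OW` extents that the grid dimensions above iterate over follow the usual convolution output-size formula. A hedged restatement (mirroring ggml's `ggml_calc_conv_output_size`, not code from this hunk):

```c
// Output extent of a convolution with kernel k, stride s, padding p, dilation d.
// Worked example: in=10, k=3, s=1, p=1, d=1 -> (10 + 2 - 2 - 1)/1 + 1 = 10.
static inline int64_t conv_output_size(int64_t in, int k, int s, int p, int d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}
```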
```diff
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti* i_ptr = src + nc * I_HW;
+    To* o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = max(0, start_h);
+    const int eh = min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = max(0, start_w);
+    const int ew = min(iw, start_w + kw);
+    const To scale = 1. / (kh * kw);
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+            Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += cur * scale; break;
+                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
         const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
```
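One behavioral note on `pool2d_nchw_kernel` above: `scale = 1. / (kh * kw)` divides by the full window area, so padded (out-of-bounds) taps count as zeros toward the average, i.e. count_include_pad semantics. A minimal CPU reference for a single output element under the same assumption (illustrative sketch, not upstream code):

```c
// CPU reference for one avg-pool output element with count_include_pad
// semantics matching the CUDA kernel above.
static float avg_pool_at(const float * in, int ih, int iw,
                         int oh_idx, int ow_idx, int kh, int kw,
                         int sh, int sw, int ph, int pw) {
    const float scale = 1.0f / (kh * kw); // full window, padding included
    float res = 0.0f;
    for (int i = 0; i < kh; ++i) {
        for (int j = 0; j < kw; ++j) {
            const int y = oh_idx * sh - ph + i;
            const int x = ow_idx * sw - pw + j;
            if (y >= 0 && y < ih && x >= 0 && x < iw) {
                res += in[y * iw + x] * scale; // out-of-bounds taps add 0
            }
        }
    }
    return res;
}
```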
```diff
@@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
```
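All of these launch helpers size their 1D grids with the usual ceiling division; spelled out as a standalone helper (illustrative, not upstream code):

```c
// ceil(k / block) in integer arithmetic, as in (k + BLOCK_SIZE - 1) / BLOCK_SIZE.
static inline int ceil_div(int k, int block) {
    return (k + block - 1) / block;
}
```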
```diff
@@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const

 static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows, 1);
+    const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }

```
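A hedged note on why moving the rows from `gridDim.y` to `gridDim.x` fixes the overflow mentioned in the commit message: CUDA caps the y and z grid dimensions at 65535 blocks, while x allows up to 2^31−1, so launching one block per row along x keeps very tall tensors in range.

```c
// Illustrative bounds check (not in the patch): the old dim3(1, nrows, 1)
// launch was limited to nrows <= 65535; dim3(nrows, 1, 1) extends that to
// the gridDim.x maximum.
static int sum_rows_launch_ok(int nrows) {
    return nrows >= 1 && nrows <= 2147483647; // 2^31 - 1
}
```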
```diff
@@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     }
 }

-static void im2col_f32_f16_cuda(const float* x, half* dst,
+template <typename T>
+static void im2col_cuda(const float* x, T* dst,
     int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int offset_delta,
+    int batch, int batch_offset, int offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, IC);
-    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }

 // buffer pool for cuda
```
```diff
@@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
     (void) src1_dd;
 }

+static void ggml_cuda_op_hardsigmoid(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_hardswish(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_leaky_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
```
```diff
@@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
     (void) src1_dd;
 }

+static void ggml_cuda_op_pool2d(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
+
+    (void) src1;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_im2col(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {

     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);

     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
     const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
     const int64_t OW = dst->ne[1];

     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32

-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }

     (void) src0;
     (void) src0_dd;
```
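For readability, the `op_params` layout that `ggml_cuda_op_pool2d` decodes above matches the `int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }` written by `ggml_pool_2d` later in this diff. An illustrative view (the struct itself is not upstream API):

```c
// Illustrative view of dst->op_params for GGML_OP_POOL_2D.
struct pool2d_op_params {
    int32_t op;     // GGML_OP_POOL_AVG or GGML_OP_POOL_MAX
    int32_t k0, k1; // kernel size
    int32_t s0, s1; // stride
    int32_t p0, p1; // padding
};
```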
```diff
@@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }

+static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
+}
+
+static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
+}
+
 static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
 }
```
```diff
@@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

+static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
+}
+
 static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
```
```diff
@@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_UNARY_OP_RELU:
             func = ggml_cuda_relu;
             break;
+        case GGML_UNARY_OP_HARDSIGMOID:
+            func = ggml_cuda_hardsigmoid;
+            break;
+        case GGML_UNARY_OP_HARDSWISH:
+            func = ggml_cuda_hardswish;
+            break;
         default:
             return false;
     }
```
```diff
@@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_IM2COL:
             func = ggml_cuda_im2col;
             break;
+        case GGML_OP_POOL_2D:
+            func = ggml_cuda_pool2d;
+            break;
         case GGML_OP_SUM_ROWS:
             func = ggml_cuda_sum_rows;
             break;
```
```diff
@@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             case GGML_UNARY_OP_GELU:
             case GGML_UNARY_OP_SILU:
             case GGML_UNARY_OP_RELU:
+            case GGML_UNARY_OP_HARDSIGMOID:
+            case GGML_UNARY_OP_HARDSWISH:
             case GGML_UNARY_OP_GELU_QUICK:
             case GGML_UNARY_OP_TANH:
                 return true;
```
```diff
@@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
```
ggml.c (118 changes)
```diff
@@ -5349,7 +5349,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int s0,
         int p0,
         int d0) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]

     struct ggml_tensor * result =
         ggml_mul_mat(ctx,
```
```diff
@@ -5427,16 +5427,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
         int p1,
         int d0,
         int d1) {

     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
                                         ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
-                                        s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
+                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]

-    struct ggml_tensor * result =
-        ggml_mul_mat(ctx,
-                ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
-                ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
-
+    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
     result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]

     return result;
```
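A hedged usage sketch of the rewritten depthwise path (tensor setup assumed; `kernel` holds one KW x KH filter per input channel):

```c
// Depthwise 3x3 convolution, stride 1, padding 1, dilation 1 (illustrative).
struct ggml_tensor * out = ggml_conv_depthwise_2d(ctx, kernel, image,
        /*s0*/ 1, /*s1*/ 1, /*p0*/ 1, /*p1*/ 1, /*d0*/ 1, /*d1*/ 1);
```

The decomposition folds the channel dimension into the batch dimension, runs a per-channel im2col, and finishes with a single mat-mul, which is what lets it reuse the existing Im2Col and MulMat backends.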
```diff
@@ -5457,7 +5456,8 @@ struct ggml_tensor * ggml_im2col(
         int p1,
         int d0,
         int d1,
-        bool is_2D) {
+        bool is_2D,
+        enum ggml_type dst_type) {

     if(is_2D) {
         GGML_ASSERT(a->ne[2] == b->ne[2]);
@@ -5481,7 +5481,7 @@ struct ggml_tensor * ggml_im2col(
         is_2D ? b->ne[3] : 1,
     };

-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
```
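Usage sketch for the extended signature (the new trailing `dst_type` selects F16 or F32 output; `ctx`, `kernel`, and `image` are assumed to exist):

```c
// 2D im2col with an f32 destination (illustrative call).
struct ggml_tensor * cols = ggml_im2col(ctx, kernel, image,
        /*s0*/ 1, /*s1*/ 1, /*p0*/ 1, /*p1*/ 1, /*d0*/ 1, /*d1*/ 1,
        /*is_2D*/ true, GGML_TYPE_F32); // result: [N, OH, OW, IC*KH*KW]
```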
```diff
@@ -5506,7 +5506,7 @@ struct ggml_tensor * ggml_conv_2d(
         int p1,
         int d0,
         int d1) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]

     struct ggml_tensor * result =
         ggml_mul_mat(ctx,
```
```diff
@@ -5632,12 +5632,13 @@ struct ggml_tensor * ggml_pool_2d(
         is_node = true;
     }

+    struct ggml_tensor * result;
     const int64_t ne[3] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
         a->ne[2],
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);

     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
     ggml_set_op_params(result, params, sizeof(params));
@@ -5645,7 +5646,6 @@ struct ggml_tensor * ggml_pool_2d(
     result->op = GGML_OP_POOL_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-
     return result;
 }

```
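The `ne` extents above come from the standard pooling output-size formula; a hedged restatement of `ggml_calc_pool_output_size` with a worked example:

```c
// out = (in + 2*p - k)/s + 1; e.g. in=10, k=3, s=1, p=1 -> (10 + 2 - 3)/1 + 1 = 10.
static int64_t pool_output_size(int64_t in, int k, int s, float p) {
    return (int64_t) ((in + 2 * p - k) / s) + 1;
}
```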
```diff
@@ -12493,6 +12493,92 @@ static void ggml_compute_forward_conv_transpose_1d(
     }
 }

+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
 // src0: kernel [OC, IC, KH, KW]
 // src1: image [N, IC, IH, IW]
 // dst:  result [N, OH, OW, IC*KH*KW]
```
```diff
@@ -12583,14 +12669,14 @@ static void ggml_compute_forward_im2col(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    switch (src0->type) {
+    switch (dst->type) {
         case GGML_TYPE_F16:
             {
                 ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                GGML_ASSERT(false);
+                ggml_compute_forward_im2col_f32(params, src0, src1, dst);
             } break;
         default:
             {
```
```diff
@@ -12781,8 +12867,8 @@ static void ggml_compute_forward_pool_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src,
               struct ggml_tensor * dst) {
-    assert(src->type == GGML_TYPE_F32);
-    assert(params->ith == 0);
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->ith == 0);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
```
ggml.h (3 changes)
```diff
@@ -1495,7 +1495,8 @@ extern "C" {
         int p1,
         int d0,
         int d1,
-        bool is_2D);
+        bool is_2D,
+        enum ggml_type dst_type);

     GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
         struct ggml_context * ctx,
```
test-backend-ops.cpp

```diff
@@ -227,6 +227,14 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }

+static std::string var_to_str(ggml_op_pool pool) {
+    switch (pool) {
+        case GGML_OP_POOL_AVG: return "avg";
+        case GGML_OP_POOL_MAX: return "max";
+        default:               return std::to_string(pool);
+    }
+}
+
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
 #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
 #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@@ -238,6 +246,7 @@ static std::string var_to_str(ggml_type type) {
 #define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
 #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
 #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
+#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)

 #ifdef GGML_USE_SYCL
 static bool inline _isinf(float f) {
```
```diff
@@ -1162,10 +1171,45 @@ struct test_alibi : public test_case {
     }
 };

+// GGML_OP_POOL2D
+struct test_pool2d : public test_case {
+    enum ggml_op_pool pool_type;
+    const ggml_type type_input;
+    const std::array<int64_t, 4> ne_input;
+    // kernel size
+    const int k0;
+    const int k1;
+    // stride
+    const int s0;
+    const int s1;
+    // padding
+    const int p0;
+    const int p1;
+
+    std::string vars() override {
+        return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
+    }
+
+    test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
+            ggml_type type_input = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+            int k0 = 3, int k1 = 3,
+            int s0 = 1, int s1 = 1,
+            int p0 = 1, int p1 = 1)
+        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
+        return out;
+    }
+};
+
 // GGML_OP_IM2COL
 struct test_im2col : public test_case {
     const ggml_type type_input;
     const ggml_type type_kernel;
+    const ggml_type dst_type;
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     // stride
```
```diff
@@ -1181,22 +1225,22 @@ struct test_im2col : public test_case {
     const bool is_2D;

     std::string vars() override {
-        return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
+        return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
     }

-    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
+    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
             std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
             int s0 = 1, int s1 = 1,
             int p0 = 1, int p1 = 1,
             int d0 = 1, int d1 = 1,
             bool is_2D = true)
-        : type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
+        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
-        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
+        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
         return out;
     }
 };
```
```diff
@@ -1912,6 +1956,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }

+    for (ggml_type type_input : {GGML_TYPE_F32}) {
+        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+            for (int k0 : {1, 3}) {
+                for (int k1 : {1, 3}) {
+                    for (int s0 : {1, 2}) {
+                        for (int s1 : {1, 2}) {
+                            for (int p0 : {0, 1}) {
+                                for (int p1 : {0, 1}) {
+                                    test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+
 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
```
```diff
@@ -2049,7 +2114,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }

     test_cases.emplace_back(new test_alibi());
-    test_cases.emplace_back(new test_im2col());
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
     test_cases.emplace_back(new test_concat(GGML_TYPE_I32));

```