mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 22:30:32 +01:00
95 lines
3.2 KiB
Plaintext
95 lines
3.2 KiB
Plaintext
|
#include "pool2d.cuh"

#include <climits>
|
||
|
|
||
|
// 2D average / max pooling over a stack of densely packed 2D planes.
//
// Launch layout: 1D grid, 1D blocks. Each thread computes exactly one output
// element; threads whose flat index is >= parallel_elements exit immediately.
//
// Indexing, as implemented below:
//   - the flat thread index is split into (plane, output row, output column)
//     using oh * ow elements per output plane and ow elements per output row;
//   - input plane `nc` starts at src + nc * ih * iw and is read row-major
//     with row length iw;
//   - output plane `nc` starts at dst + nc * oh * ow.
//
// Parameters:
//   ih, iw            input plane height / width
//   oh, ow            output plane height / width
//   kh, kw            pooling window height / width
//   sh, sw            stride along height / width
//   ph, pw            padding along height / width (window start is shifted by -ph / -pw)
//   parallel_elements total number of output elements to produce
//   src, dst          input / output buffers
//   op                GGML_OP_POOL_AVG or GGML_OP_POOL_MAX; anything else trips assert(false)
//
// Semantics worth knowing:
//   - The window is clamped to the input bounds, so padded positions are never read.
//   - AVG always scales by 1 / (kh * kw), i.e. padded positions count as zeros
//     in the average rather than shrinking the divisor.
//   - MAX starts from -FLT_MAX, so a window that lies entirely in the padding
//     yields -FLT_MAX.
template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel(
        const int ih, const int iw, const int oh, const int ow,
        const int kh, const int kw, const int sh, const int sw,
        const int ph, const int pw, const int parallel_elements,
        const Ti* src, To* dst, const enum ggml_op_pool op) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx >= parallel_elements) {
        return;
    }

    const int I_HW = ih * iw;
    const int O_HW = oh * ow;

    // Decompose the flat output index into plane / row / column.
    const int nc       = idx / O_HW;
    const int in_plane = idx % O_HW;
    const int cur_oh   = in_plane / ow;
    const int cur_ow   = in_plane % ow;

    // The plane offsets are computed in 64 bits: the input can hold more
    // elements than the output, so nc * I_HW may exceed INT_MAX even though
    // idx itself fits in an int.
    const Ti* i_ptr = src + (int64_t) nc * I_HW;
    To*       o_ptr = dst + (int64_t) nc * O_HW;

    // Window extent along height, clamped to [0, ih).
    const int start_h = cur_oh * sh - ph;
    const int bh      = max(0, start_h);
    const int eh      = min(ih, start_h + kh);

    // Window extent along width, clamped to [0, iw).
    const int start_w = cur_ow * sw - pw;
    const int bw      = max(0, start_w);
    const int ew      = min(iw, start_w + kw);

    // Single-precision literal: a bare `1.` would promote the division to FP64.
    const To scale = 1.0f / (kh * kw);
    To res = 0;

    // Pick the identity element of the reduction.
    switch (op) {
        case GGML_OP_POOL_AVG: res = 0;        break;
        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
        default: assert(false); break;
    }

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
            // Load through the read-only data cache where the intrinsic is available.
            Ti cur = __ldg(i_ptr + i * iw + j);
#else
            Ti cur = i_ptr[i * iw + j];
#endif
            switch (op) {
                case GGML_OP_POOL_AVG: res += cur * scale;         break;
                case GGML_OP_POOL_MAX: res = max(res, (To)cur);    break;
                default: assert(false); break;
            }
        }
    }

    o_ptr[cur_oh * ow + cur_ow] = res;
}
|
||
|
|
||
|
// Host-side launcher for the float -> float instantiation of pool2d_nchw_kernel.
//
// Launches a 1D grid of ceil(parallel_elements / CUDA_POOL2D_BLOCK_SIZE) blocks
// with CUDA_POOL2D_BLOCK_SIZE threads each, no dynamic shared memory, on `stream`.
// The launch is asynchronous; this function does not synchronize.
//
// All geometry arguments (ih .. pw), parallel_elements, src, dst and op are
// forwarded to the kernel unchanged.
static void pool2d_nchw_kernel_f32_f32_cuda(
        const int ih, const int iw, const int oh, const int ow,
        const int kh, const int kw, const int sh, const int sw,
        const int ph, const int pw, const int parallel_elements,
        const float * src, float * dst, const enum ggml_op_pool op,
        cudaStream_t stream) {

    // Nothing to compute. A grid with zero blocks is an invalid launch
    // configuration, so return instead of issuing a failing launch.
    if (parallel_elements <= 0) {
        return;
    }

    // Ceil-division so the tail of the output is covered by a partially used block.
    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
    dim3 block_nums(num_blocks);
    pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(
        ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
}
|
||
|
|
||
|
// CUDA implementation of the 2D pooling op for F32 -> F32 tensors.
//
// Reads the pooling configuration from dst->op_params (interpreted as int32_t):
//   [0] pooling operator (cast to ggml_op_pool)
//   [1] k0, [2] k1   window size
//   [3] s0, [4] s1   stride
//   [5] p0, [6] p1   padding
// The "0" values are passed to the launcher as the width parameters and the
// "1" values as the height parameters, matching ne[0] = width, ne[1] = height.
//
// The kernel is enqueued on ctx.stream(); this function does not synchronize.
//
// ctx: backend context providing the CUDA stream.
// dst: output tensor; dst->src[0] is the input tensor. Both must be GGML_TYPE_F32.
void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    // Validate the element types before reinterpreting the data pointers.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    const int32_t * opts = (const int32_t *)dst->op_params;
    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
    const int k0 = opts[1];
    const int k1 = opts[2];
    const int s0 = opts[3];
    const int s1 = opts[4];
    const int p0 = opts[5];
    const int p1 = opts[6];

    const int64_t IH = src0->ne[1];
    const int64_t IW = src0->ne[0];

    const int64_t N  = dst->ne[3];
    const int64_t OC = dst->ne[2];
    const int64_t OH = dst->ne[1];
    const int64_t OW = dst->ne[0];

    // The launcher and kernel take 32-bit ints for the element count and the
    // plane dimensions. Fail loudly instead of silently truncating.
    const int64_t n_out = N * OC * OH * OW;
    GGML_ASSERT(n_out   <= INT_MAX);
    GGML_ASSERT(IH * IW <= INT_MAX);

    const int parallel_elements = (int) n_out;

    pool2d_nchw_kernel_f32_f32_cuda(
        (int) IH, (int) IW, (int) OH, (int) OW,
        k1, k0, s1, s0, p1, p0,
        parallel_elements, src0_d, dst_d, op, stream);
}
|