mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 06:10:29 +01:00
metal : add im2col F32 dst support (#5132)
This commit is contained in:
parent
15606309a0
commit
efb7bdbbd0
13
ggml-metal.m
13
ggml-metal.m
@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||||
@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||||
@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
|
return true;
|
||||||
|
case GGML_OP_POOL_1D:
|
||||||
|
case GGML_OP_POOL_2D:
|
||||||
|
return false;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||||
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||||
@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||||
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||||
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||||
|
|
||||||
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
||||||
@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (dst->type) {
|
||||||
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
||||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
};
|
};
|
||||||
|
@ -1775,9 +1775,29 @@ kernel void kernel_rope(
|
|||||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||||
|
|
||||||
kernel void kernel_im2col_f16(
|
typedef void (im2col_t)(
|
||||||
device const float * x,
|
device const float * x,
|
||||||
device half * dst,
|
device char * dst,
|
||||||
|
constant int32_t & ofs0,
|
||||||
|
constant int32_t & ofs1,
|
||||||
|
constant int32_t & IW,
|
||||||
|
constant int32_t & IH,
|
||||||
|
constant int32_t & CHW,
|
||||||
|
constant int32_t & s0,
|
||||||
|
constant int32_t & s1,
|
||||||
|
constant int32_t & p0,
|
||||||
|
constant int32_t & p1,
|
||||||
|
constant int32_t & d0,
|
||||||
|
constant int32_t & d1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
kernel void kernel_im2col(
|
||||||
|
device const float * x,
|
||||||
|
device char * dst,
|
||||||
constant int32_t & ofs0,
|
constant int32_t & ofs0,
|
||||||
constant int32_t & ofs1,
|
constant int32_t & ofs1,
|
||||||
constant int32_t & IW,
|
constant int32_t & IW,
|
||||||
@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
|
|||||||
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||||
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
||||||
|
|
||||||
|
device T * pdst = (device T *) (dst);
|
||||||
|
|
||||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||||
dst[offset_dst] = 0.0f;
|
pdst[offset_dst] = 0.0f;
|
||||||
} else {
|
} else {
|
||||||
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
||||||
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||||
|
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||||
|
|
||||||
kernel void kernel_upscale_f32(
|
kernel void kernel_upscale_f32(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
|
Loading…
Reference in New Issue
Block a user