metal : add GGML_OP_REPEAT kernels (#7557)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-27 12:10:19 +03:00 committed by GitHub
parent 62bfef5194
commit 1d8fca72ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 4 deletions

View File

@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_MUL_ROW,
GGML_METAL_KERNEL_TYPE_DIV, GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW, GGML_METAL_KERNEL_TYPE_DIV_ROW,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE,
GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_SCALE_4,
GGML_METAL_KERNEL_TYPE_CLAMP, GGML_METAL_KERNEL_TYPE_CLAMP,
@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_ACC: case GGML_OP_ACC:
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
case GGML_OP_REPEAT:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_SQR: case GGML_OP_SQR:
@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
switch (dst->op) { switch (dst->op) {
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
{ {
const int64_t nb = ne00;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
const int nth = MIN(1024, ne0); const int nth = MIN(1024, ne0);
@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_DIV: case GGML_OP_DIV:
{ {
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
const size_t offs = 0; const size_t offs = 0;
bool bcast_row = false; bool bcast_row = false;
int64_t nb = ne00; int64_t nb = ne00; // used by the "row" kernels
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} }
} break; } break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
switch (src0t) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
default: GGML_ASSERT(false);
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ACC: case GGML_OP_ACC:
{ {
GGML_ASSERT(src0t == GGML_TYPE_F32); GGML_ASSERT(src0t == GGML_TYPE_F32);

View File

@ -168,6 +168,53 @@ kernel void kernel_div(
} }
} }
template<typename T>
kernel void kernel_repeat(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 % ne03;
const int64_t i02 = i2 % ne02;
const int64_t i01 = i1 % ne01;
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
const int i00 = i0 % ne00;
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
// assumption: src1 is a row // assumption: src1 is a row
// broadcast src1 into src0 // broadcast src1 into src0
kernel void kernel_add_row( kernel void kernel_add_row(