metal : parallel command buffer encoding (#1860)

* metal : parallel command buffer encoding

* metal : determine number of command buffers based on gf->n_threads
Georgi Gerganov 2023-06-15 20:29:48 +03:00 committed by GitHub
parent 6b8312e797
commit 4bfcc855ab
2 changed files with 471 additions and 447 deletions
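
In outline, the change replaces the single sequential command buffer with gf->n_threads command buffers: each buffer is enqueued up front to fix the GPU execution order, a slice of the graph nodes is encoded into its own buffer from a block running on a concurrent dispatch queue, and the caller waits on the last buffer. The following is a minimal standalone sketch of that pattern, not the actual diff; graph_compute_parallel, encode_node_range, and mtl_queue are illustrative names, and the real per-op encoding logic is the large switch in the hunk below.

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// placeholder for the per-node encoding switch in ggml_metal_graph_compute
static void encode_node_range(id<MTLCommandBuffer> cb, int node_start, int node_end) {
    (void) cb; (void) node_start; (void) node_end;
}

static void graph_compute_parallel(id<MTLCommandQueue> mtl_queue, int n_nodes, int n_cb) {
    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];

    for (int i = 0; i < n_cb; ++i) {
        command_buffers[i] = [mtl_queue commandBuffer];

        // enqueue up front so the buffers execute in this order regardless of commit order
        [command_buffers[i] enqueue];
    }

    dispatch_queue_t queue = dispatch_queue_create("example", DISPATCH_QUEUE_CONCURRENT);

    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        dispatch_async(queue, ^{
            const int node_start = cb_idx * n_nodes_per_cb;
            const int node_end   = cb_idx == n_cb - 1 ? n_nodes : (cb_idx + 1)*n_nodes_per_cb;

            // encode this slice of the graph into its own command buffer
            id<MTLCommandBuffer> cb = command_buffers[cb_idx];
            encode_node_range(cb, node_start, node_end);

            [cb commit];
        });
    }

    // wait for the encoding blocks to finish, then for the last buffer (and thus all of them)
    dispatch_barrier_sync(queue, ^{});

    [command_buffers[n_cb - 1] waitUntilCompleted];
}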

ggml-metal.h

@@ -55,6 +55,7 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);

// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);

#ifdef __cplusplus

ggml-metal.m

@@ -284,528 +284,551 @@ void ggml_metal_get_tensor(
void ggml_metal_graph_compute(
        struct ggml_metal_context * ctx,
               struct ggml_cgraph * gf) {
    metal_printf("%s: evaluating graph\n", __func__);

    // create multiple command buffers and enqueue them
    // then, we encode the graph into the command buffers in parallel

    const int n_cb = gf->n_threads;

    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];

    for (int i = 0; i < n_cb; ++i) {
        command_buffers[i] = [ctx->queue commandBuffer];

        // enqueue the command buffers in order to specify their execution order
        [command_buffers[i] enqueue];
    }

    // TODO: is this the best way to start threads?
    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;

        dispatch_async(queue, ^{
            size_t offs_src0 = 0;
            size_t offs_src1 = 0;
            size_t offs_dst  = 0;

            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];

            id<MTLComputeCommandEncoder> encoder = nil;

            const int node_start = (cb_idx + 0) * n_nodes_per_cb;
            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;

            for (int i = node_start; i < node_end; ++i) {
                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

                struct ggml_tensor * src0 = gf->nodes[i]->src0;
                struct ggml_tensor * src1 = gf->nodes[i]->src1;
                struct ggml_tensor * dst  = gf->nodes[i];

                const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                const int64_t  ne01 = src0 ? src0->ne[1] : 0;
                const int64_t  ne02 = src0 ? src0->ne[2] : 0;
                const int64_t  ne03 = src0 ? src0->ne[3] : 0;

                const uint64_t nb00 = src0 ? src0->nb[0] : 0;
                const uint64_t nb01 = src0 ? src0->nb[1] : 0;
                const uint64_t nb02 = src0 ? src0->nb[2] : 0;
                const uint64_t nb03 = src0 ? src0->nb[3] : 0;

                const int64_t  ne10 = src1 ? src1->ne[0] : 0;
                const int64_t  ne11 = src1 ? src1->ne[1] : 0;
                const int64_t  ne12 = src1 ? src1->ne[2] : 0;
                const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);

                const uint64_t nb10 = src1 ? src1->nb[0] : 0;
                const uint64_t nb11 = src1 ? src1->nb[1] : 0;
                const uint64_t nb12 = src1 ? src1->nb[2] : 0;
                const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);

                const int64_t  ne0 = dst ? dst->ne[0] : 0;
                const int64_t  ne1 = dst ? dst->ne[1] : 0;
                const int64_t  ne2 = dst ? dst->ne[2] : 0;
                const int64_t  ne3 = dst ? dst->ne[3] : 0;

                const uint64_t nb0 = dst ? dst->nb[0] : 0;
                const uint64_t nb1 = dst ? dst->nb[1] : 0;
                const uint64_t nb2 = dst ? dst->nb[2] : 0;
                const uint64_t nb3 = dst ? dst->nb[3] : 0;

                const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
                const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
                const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;

                id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
                id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
                id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;

                //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
                //if (src0) {
                //    metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
                //            ggml_is_contiguous(src0), src0->name);
                //}
                //if (src1) {
                //    metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
                //            ggml_is_contiguous(src1), src1->name);
                //}
                //if (dst) {
                //    metal_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
                //            dst->name);
                //}

                switch (dst->op) {
                    case GGML_OP_RESHAPE:
                    case GGML_OP_VIEW:
                    case GGML_OP_TRANSPOSE:
                    case GGML_OP_PERMUTE:
                        {
                            // noop
                        } break;
                    case GGML_OP_ADD:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            [encoder setComputePipelineState:ctx->pipeline_add];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_MUL:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            if (ggml_nelements(src1) == ne10) {
                                // src1 is a row
                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
                            } else {
                                [encoder setComputePipelineState:ctx->pipeline_mul];
                            }
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_SCALE:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const float scale = *(const float *) src1->data;

                            [encoder setComputePipelineState:ctx->pipeline_scale];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_SILU:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            [encoder setComputePipelineState:ctx->pipeline_silu];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_RELU:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            [encoder setComputePipelineState:ctx->pipeline_relu];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_GELU:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            [encoder setComputePipelineState:ctx->pipeline_gelu];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                            const int64_t n = ggml_nelements(dst);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_SOFT_MAX:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const int nth = 32;

                            [encoder setComputePipelineState:ctx->pipeline_soft_max];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                        } break;
                    case GGML_OP_DIAG_MASK_INF:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const int n_past = ((int32_t *)(src1->data))[0];

                            [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                            [encoder setBytes:&n_past length:sizeof(int) atIndex:4];

                            [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_MUL_MAT:
                        {
                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224

                            GGML_ASSERT(ne00 == ne10);
                            GGML_ASSERT(ne02 == ne12);

                            if (ggml_is_contiguous(src0) &&
                                ggml_is_contiguous(src1) &&
                                (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {

                                if (encoder != nil) {
                                    [encoder endEncoding];
                                    encoder = nil;
                                }

                                MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
                                MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;

                                // for F32 x F32 we use MPS
                                MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
                                    matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];

                                MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
                                    matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];

                                MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
                                    matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];

                                MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
                                    initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                        resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];

                                // we need to do ne02 multiplications
                                // TODO: is there a way to do this in parallel - currently very slow ..
                                // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
                                    size_t offs_src1_cur = offs_src1 + i02*nb12;
                                    size_t offs_dst_cur  = offs_dst  + i02*nb2;

                                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
                                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
                                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];

                                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
                                }
                            } else {
                                if (encoder == nil) {
                                    encoder = [command_buffer computeCommandEncoder];
                                }

                                int nth0 = 32;
                                int nth1 = 1;

                                // use custom matrix x vector kernel
                                switch (src0t) {
                                    case GGML_TYPE_F16:
                                        {
                                            GGML_ASSERT(ne02 == ne12);

                                            nth0 = 64;
                                            nth1 = 1;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                                        } break;
                                    case GGML_TYPE_Q4_0:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 8;
                                            nth1 = 8;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                        } break;
                                    case GGML_TYPE_Q4_1:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 8;
                                            nth1 = 8;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                        } break;
                                    case GGML_TYPE_Q2_K:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 4;
                                            nth1 = 16;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
                                        } break;
                                    case GGML_TYPE_Q3_K:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 4;
                                            nth1 = 16;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
                                        } break;
                                    case GGML_TYPE_Q4_K:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 4;
                                            nth1 = 16;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
                                        } break;
                                    case GGML_TYPE_Q5_K:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 4;
                                            nth1 = 16;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
                                        } break;
                                    case GGML_TYPE_Q6_K:
                                        {
                                            GGML_ASSERT(ne02 == 1);
                                            GGML_ASSERT(ne12 == 1);

                                            nth0 = 4;
                                            nth1 = 16;
                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
                                        } break;
                                    default:
                                        {
                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
                                            GGML_ASSERT(false && "not implemented");
                                        }
                                };

                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
                                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
                                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];

                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                }
                                else if (src0t == GGML_TYPE_Q2_K ||
                                         src0t == GGML_TYPE_Q3_K ||
                                         src0t == GGML_TYPE_Q4_K ||
                                         src0t == GGML_TYPE_Q5_K ||
                                         src0t == GGML_TYPE_Q6_K) {
                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                } else {
                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                }
                            }
                        } break;
                    case GGML_OP_GET_ROWS:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            switch (src0->type) {
                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
                                default: GGML_ASSERT(false && "not implemented");
                            }

                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
                            [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];

                            const int64_t n = ggml_nelements(src1);

                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_RMS_NORM:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const float eps = 1e-6f;

                            const int nth = 256;

                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                            [encoder setBytes:&eps length:sizeof( float) atIndex:4];
                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

                            const int64_t nrows = ggml_nrows(src0);

                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                        } break;
                    case GGML_OP_ROPE:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const int n_dims = ((int32_t *) src1->data)[1];
                            const int mode   = ((int32_t *) src1->data)[2];

                            const int n_past = ((int32_t *)(src1->data))[0];

                            [encoder setComputePipelineState:ctx->pipeline_rope];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
                            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
                            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
                            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
                            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
                            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
                            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
                            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
                            [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
                            [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
                            [encoder setBytes:&mode length:sizeof( int) atIndex:20];

                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_OP_CPY:
                        {
                            if (encoder == nil) {
                                encoder = [command_buffer computeCommandEncoder];
                            }

                            const int nth = 32;

                            switch (src0t) {
                                case GGML_TYPE_F32:
                                    {
                                        switch (dstt) {
                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
                                            default: GGML_ASSERT(false && "not implemented");
                                        };
                                    } break;
                                default: GGML_ASSERT(false && "not implemented");
                            }

                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
                            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
                            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
                            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
                            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
                            [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
                            [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
                            [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                            [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                        } break;
                    default:
                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                        GGML_ASSERT(false);
                }
            }

            if (encoder != nil) {
                [encoder endEncoding];
                encoder = nil;
            }

            [command_buffer commit];
        });
    }

    // wait for all threads to finish
    dispatch_barrier_sync(queue, ^{});

    [command_buffers[n_cb - 1] waitUntilCompleted];
}