metal : relax conditions on fast matrix multiplication kernel (#3168)

* metal : relax conditions on fast matrix multiplication kernel

* metal : revert the concurrnecy change because it was wrong

* llama : remove experimental stuff
This commit is contained in:
Georgi Gerganov 2023-09-15 11:09:24 +03:00 committed by GitHub
parent 76164fe2e6
commit a51b687657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 51 deletions

View File

@ -66,6 +66,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8); GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f32);
GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@ -145,7 +146,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0; ctx->n_buffers = 0;
ctx->concur_list_len = 0; ctx->concur_list_len = 0;
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
#ifdef GGML_SWIFT #ifdef GGML_SWIFT
// load the default.metallib file // load the default.metallib file
@ -175,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8); GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f32);
GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4); GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8); GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f32);
GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1); GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@ -386,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
for (int i = 0; i < ctx->n_buffers; ++i) { for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs; *offs = (size_t) ioffs;
@ -723,6 +728,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4 // utilize float4
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
@ -730,6 +736,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_add_row]; [encoder setComputePipelineState:ctx->pipeline_add_row];
} else { } else {
[encoder setComputePipelineState:ctx->pipeline_add]; [encoder setComputePipelineState:ctx->pipeline_add];
@ -746,6 +753,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL: case GGML_OP_MUL:
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4 // utilize float4
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
@ -753,6 +761,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_mul_row]; [encoder setComputePipelineState:ctx->pipeline_mul_row];
} else { } else {
[encoder setComputePipelineState:ctx->pipeline_mul]; [encoder setComputePipelineState:ctx->pipeline_mul];
@ -768,6 +777,8 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale]; [encoder setComputePipelineState:ctx->pipeline_scale];
@ -867,8 +878,8 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) && if (!ggml_is_transposed(src0) &&
ggml_is_contiguous(src1) && !ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] && [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 && ne00%32 == 0 &&
@ -893,9 +904,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else {
@ -1045,6 +1059,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@ -1060,9 +1075,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1); const int64_t n = ggml_nelements(src1);

View File

@ -38,7 +38,7 @@ kernel void kernel_add_row(
device const float4 * src0, device const float4 * src0,
device const float4 * src1, device const float4 * src1,
device float4 * dst, device float4 * dst,
constant int64_t & nb, constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb]; dst[tpig] = src0[tpig] + src1[tpig % nb];
} }
@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32(
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
} }
} }
} }
#else #else
kernel void kernel_mul_mat_q3_K_f32( kernel void kernel_mul_mat_q3_K_f32(
@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
//============================= templates and their specializations ============================= //============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
float4x4 temp = *(((device float4x4 *)src));
for (int i = 0; i < 16; i++){
reg[i/4][i%4] = temp[i/4][i%4];
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
half4x4 temp = *(((device half4x4 *)src)); half4x4 temp = *(((device half4x4 *)src));
@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4> template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1); device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f; const float d2 = d1 / 256.f;
@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
} }
} }
template <typename type4x4> template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2); device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f; const float d2 = d1 / 256.f;
@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
} }
#else #else
float kcoef = il&1 ? 1.f/16.f : 1.f; float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F; uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@ -2110,22 +2114,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights // each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0, kernel void kernel_mul_mm(device const uchar * src0,
device const float * src1, device const uchar * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne02, constant int64_t & ne02,
constant int64_t & nb01, constant int64_t & nb01,
constant int64_t & nb02, constant int64_t & nb02,
constant int64_t & ne12, constant int64_t & ne12,
constant int64_t & ne0, constant int64_t & nb10,
constant int64_t & ne1, constant int64_t & nb11,
constant uint & gqa, constant int64_t & nb12,
threadgroup uchar * shared_memory [[threadgroup(0)]], constant int64_t & ne0,
uint3 tgpig[[threadgroup_position_in_grid]], constant int64_t & ne1,
uint tiitg[[thread_index_in_threadgroup]], constant uint & gqa,
uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = ((threadgroup half *)shared_memory); threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y; const uint r0 = tgpig.y;
@ -2138,7 +2145,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_half8x8 ma[4]; simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2]; simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8]; simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){ for (int i = 0; i < 8; i++){
@ -2146,10 +2153,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
short il = (tiitg % THREAD_PER_ROW); short il = (tiitg % THREAD_PER_ROW);
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; uint offset0 = im/gqa*nb02;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ ushort offset1 = il/nl;
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory //load data and store to threadgroup memory
@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
constant uint64_t &, constant uint64_t &, uint, uint, uint); constant uint64_t &, constant uint64_t &, uint, uint, uint);
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ typedef void (mat_mm_t)(
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ device const uchar * src0,
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;

20
ggml.c
View File

@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
} }
size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); size_t nbytes;
for (int i = 1; i < GGML_MAX_DIMS; ++i) { size_t blck_size = ggml_blck_size(tensor->type);
nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; if (blck_size == 1) {
nbytes = ggml_type_size(tensor->type);
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
} }
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
return nbytes; return nbytes;
} }
@ -18340,7 +18351,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
i, i,
node->ne[0], node->ne[1], node->ne[0], node->ne[1],
ggml_op_name(node->op)); ggml_op_name(node->op),
ggml_get_name(node));
} }
for (int i = 0; i < GGML_OP_COUNT; i++) { for (int i = 0; i < GGML_OP_COUNT; i++) {

View File

@ -3429,10 +3429,6 @@ static bool llama_eval_internal(
if (lctx.ctx_metal) { if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_graph_compute(lctx.ctx_metal, gf);
ggml_metal_get_tensor (lctx.ctx_metal, res);
if (!lctx.embedding.empty()) {
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else { } else {
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
} }