This commit is contained in:
Georgi Gerganov 2023-10-07 19:20:40 +03:00
parent 8f6ad68427
commit 545b03491c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 18 additions and 16 deletions

View File

@ -994,7 +994,7 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne03 == ne13);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared // find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel. the numbers below are measure on M2 Ultra // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
// not sure if this translates across all chips // not sure if this translates across all chips
int ne11_mm_min = 1; int ne11_mm_min = 1;
@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (!ggml_is_transposed(src0) && if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) && !ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00 % 32 == 0 && ne00 % 32 == 0 &&
ne11 > ne11_mm_min) { ne11 > ne11_mm_min) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@ -1054,6 +1055,7 @@ void ggml_metal_graph_compute(
int nth0 = 32; int nth0 = 32;
int nth1 = 1; int nth1 = 1;
int nrows = 1; int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// use custom matrix x vector kernel // use custom matrix x vector kernel
switch (src0t) { switch (src0t) {