metal : use mm kernels for batch size > 2

This commit is contained in:
Georgi Gerganov 2023-09-28 16:02:20 +03:00
parent e9463792d3
commit 4c72ab13b2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -958,7 +958,7 @@ void ggml_metal_graph_compute(
src1t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] && [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 && ne00%32 == 0 &&
ne11 > 1) { ne11 > 2) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;