metal : rename kernels mul_mat_ to mul_mv_

This commit is contained in:
Georgi Gerganov 2023-10-07 15:06:22 +03:00
parent 99ed03a24a
commit c60022488a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 75 additions and 68 deletions

View File

@ -81,18 +81,18 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@ -262,18 +262,18 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
@ -339,18 +339,18 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
@ -1059,7 +1059,7 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F32:
{
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4;
} break;
case GGML_TYPE_F16:
@ -1067,12 +1067,12 @@ void ggml_metal_graph_compute(
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4;
}
} break;
@ -1083,7 +1083,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
@ -1092,7 +1092,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
} break;
case GGML_TYPE_Q8_0:
{
@ -1101,7 +1101,7 @@ void ggml_metal_graph_compute(
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
} break;
case GGML_TYPE_Q2_K:
{
@ -1110,7 +1110,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
@ -1119,7 +1119,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
@ -1128,7 +1128,7 @@ void ggml_metal_graph_compute(
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
@ -1137,7 +1137,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
@ -1146,7 +1146,7 @@ void ggml_metal_graph_compute(
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
} break;
default:
{

View File

@ -477,7 +477,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
}
}
kernel void kernel_mul_mat_q4_0_f32(
kernel void kernel_mul_mv_q4_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -495,7 +495,7 @@ kernel void kernel_mul_mat_q4_0_f32(
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_q4_1_f32(
kernel void kernel_mul_mv_q4_1_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -515,7 +515,7 @@ kernel void kernel_mul_mat_q4_1_f32(
#define NB_Q8_0 8
kernel void kernel_mul_mat_q8_0_f32(
kernel void kernel_mul_mv_q8_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -579,7 +579,7 @@ kernel void kernel_mul_mat_q8_0_f32(
#define N_F32_F32 4
kernel void kernel_mul_mat_f32_f32(
kernel void kernel_mul_mv_f32_f32(
device const char * src0,
device const char * src1,
device float * dst,
@ -650,7 +650,7 @@ kernel void kernel_mul_mat_f32_f32(
}
}
kernel void kernel_mul_mat_f16_f32_1row(
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
@ -704,7 +704,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
#define N_F16_F32 4
kernel void kernel_mul_mat_f16_f32(
kernel void kernel_mul_mv_f16_f32(
device const char * src0,
device const char * src1,
device float * dst,
@ -776,7 +776,7 @@ kernel void kernel_mul_mat_f16_f32(
}
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mat_f16_f32_l4(
kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0,
device const char * src1,
device float * dst,
@ -1253,7 +1253,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
//====================================== dot products =========================
kernel void kernel_mul_mat_q2_K_f32(
kernel void kernel_mul_mv_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1397,7 +1397,7 @@ kernel void kernel_mul_mat_q2_K_f32(
}
#if QK_K == 256
kernel void kernel_mul_mat_q3_K_f32(
kernel void kernel_mul_mv_q3_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1549,7 +1549,7 @@ kernel void kernel_mul_mat_q3_K_f32(
}
}
#else
kernel void kernel_mul_mat_q3_K_f32(
kernel void kernel_mul_mv_q3_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1620,7 +1620,7 @@ kernel void kernel_mul_mat_q3_K_f32(
#endif
#if QK_K == 256
kernel void kernel_mul_mat_q4_K_f32(
kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q4_K_f32(
}
}
#else
kernel void kernel_mul_mat_q4_K_f32(
kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1815,7 +1815,7 @@ kernel void kernel_mul_mat_q4_K_f32(
}
#endif
kernel void kernel_mul_mat_q5_K_f32(
kernel void kernel_mul_mv_q5_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -1988,7 +1988,7 @@ kernel void kernel_mul_mat_q5_K_f32(
}
kernel void kernel_mul_mat_q6_K_f32(
kernel void kernel_mul_mv_q6_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
@ -2363,9 +2363,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
const uint im = tgpig.z;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@ -2393,22 +2395,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
half4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll(16)
for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
}
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
= *((device float2x4 *)y);
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x;
y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup);
//load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(4)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
#pragma unroll(4)
@ -2423,6 +2429,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
#pragma unroll(8)
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@ -2431,7 +2438,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);