mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 00:39:00 +01:00
metal : rename kernels mul_mat_ to mul_mv_
This commit is contained in:
parent
99ed03a24a
commit
c60022488a
96
ggml-metal.m
96
ggml-metal.m
@ -81,18 +81,18 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||
@ -262,18 +262,18 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
||||
@ -339,18 +339,18 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
||||
@ -1059,7 +1059,7 @@ void ggml_metal_graph_compute(
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
||||
nrows = 4;
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
@ -1067,12 +1067,12 @@ void ggml_metal_graph_compute(
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||
nrows = 4;
|
||||
}
|
||||
} break;
|
||||
@ -1083,7 +1083,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
@ -1092,7 +1092,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
@ -1101,7 +1101,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
@ -1110,7 +1110,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
{
|
||||
@ -1119,7 +1119,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
{
|
||||
@ -1128,7 +1128,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 4; //1;
|
||||
nth1 = 8; //32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
{
|
||||
@ -1137,7 +1137,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
@ -1146,7 +1146,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -477,7 +477,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q4_0_f32(
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -495,7 +495,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q4_1_f32(
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -515,7 +515,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||
|
||||
#define NB_Q8_0 8
|
||||
|
||||
kernel void kernel_mul_mat_q8_0_f32(
|
||||
kernel void kernel_mul_mv_q8_0_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -579,7 +579,7 @@ kernel void kernel_mul_mat_q8_0_f32(
|
||||
|
||||
#define N_F32_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f32_f32(
|
||||
kernel void kernel_mul_mv_f32_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -650,7 +650,7 @@ kernel void kernel_mul_mat_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32_1row(
|
||||
kernel void kernel_mul_mv_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -704,7 +704,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
||||
|
||||
#define N_F16_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32(
|
||||
kernel void kernel_mul_mv_f16_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -776,7 +776,7 @@ kernel void kernel_mul_mat_f16_f32(
|
||||
}
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
kernel void kernel_mul_mat_f16_f32_l4(
|
||||
kernel void kernel_mul_mv_f16_f32_l4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
@ -1253,7 +1253,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||
|
||||
//====================================== dot products =========================
|
||||
|
||||
kernel void kernel_mul_mat_q2_K_f32(
|
||||
kernel void kernel_mul_mv_q2_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1397,7 +1397,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
kernel void kernel_mul_mv_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1549,7 +1549,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
}
|
||||
}
|
||||
#else
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
kernel void kernel_mul_mv_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1620,7 +1620,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
#endif
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q4_K_f32(
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
}
|
||||
}
|
||||
#else
|
||||
kernel void kernel_mul_mat_q4_K_f32(
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1815,7 +1815,7 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
}
|
||||
#endif
|
||||
|
||||
kernel void kernel_mul_mat_q5_K_f32(
|
||||
kernel void kernel_mul_mv_q5_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -1988,7 +1988,7 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_q6_K_f32(
|
||||
kernel void kernel_mul_mv_q6_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
@ -2363,9 +2363,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
const uint r0 = tgpig.y;
|
||||
const uint r1 = tgpig.x;
|
||||
const uint im = tgpig.z;
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
@ -2393,22 +2395,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
half4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma unroll(16)
|
||||
for (int i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
}
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
|
||||
= *((device float2x4 *)y);
|
||||
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||
y += BLOCK_SIZE_K;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
//load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||
#pragma unroll(4)
|
||||
@ -2423,6 +2429,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
|
||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (int i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
||||
@ -2431,8 +2438,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
||||
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
|
||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
|
||||
device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user