mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 05:48:47 +01:00
metal : PP speedup (#3084)
* Minor speed gains for all quantization types * metal: faster kernel_scale via float4 * Various other speedups for "small" kernels * metal: faster soft_max vial float4 * metal: faster diagonal infinity Although, to me it looks like one should simply fuse scale + diagnonal infinity + soft_max on the KQtensor. * Another faster f16 x f32 matrix multiply kernel * Reverting the diag infinity change It does work for PP, but somehow it fails for TG. Need to look more into it. * metal: add back faster diagonal infinity This time more carefully * metal : minor (readibility) --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6eeb4d9083
commit
f31b6f4e2d
37
ggml-metal.m
37
ggml-metal.m
@ -63,7 +63,9 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
GGML_METAL_DECL_KERNEL(gelu);
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
@ -77,6 +79,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -218,7 +221,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
GGML_METAL_ADD_KERNEL(gelu);
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
@ -232,6 +237,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -286,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(relu);
|
GGML_METAL_DEL_KERNEL(relu);
|
||||||
GGML_METAL_DEL_KERNEL(gelu);
|
GGML_METAL_DEL_KERNEL(gelu);
|
||||||
GGML_METAL_DEL_KERNEL(soft_max);
|
GGML_METAL_DEL_KERNEL(soft_max);
|
||||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||||
|
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||||
@ -300,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||||
@ -767,7 +775,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -779,7 +787,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -799,7 +807,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst)/4;
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -813,13 +821,16 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
|
if (ne00%4 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||||
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
@ -827,14 +838,23 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
|
if (ne00%8 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
||||||
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||||
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||||
|
|
||||||
|
if (ne00%8 == 0) {
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
}
|
||||||
|
else {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
@ -881,6 +901,7 @@ void ggml_metal_graph_compute(
|
|||||||
} else {
|
} else {
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
|
int nrows = 1;
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
@ -890,8 +911,12 @@ void ggml_metal_graph_compute(
|
|||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (ne11 * ne12 < 4) {
|
if (ne11 * ne12 < 4) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||||
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||||
|
nrows = ne11;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
|
nrows = 4;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
@ -1012,7 +1037,7 @@ void ggml_metal_graph_compute(
|
|||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + 3)/4;
|
int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
279
ggml-metal.metal
279
ggml-metal.metal
@ -63,18 +63,18 @@ kernel void kernel_mul_row(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_scale(
|
kernel void kernel_scale(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
constant float & scale,
|
constant float & scale,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_silu(
|
kernel void kernel_silu(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
float x = src0[tpig];
|
device const float4 & x = src0[tpig];
|
||||||
dst[tpig] = x / (1.0f + exp(-x));
|
dst[tpig] = x / (1.0f + exp(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
|
|||||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
kernel void kernel_gelu(
|
kernel void kernel_gelu(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
float x = src0[tpig];
|
device const float4 & x = src0[tpig];
|
||||||
|
|
||||||
// BEWARE !!!
|
// BEWARE !!!
|
||||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||||
@ -107,7 +107,6 @@ kernel void kernel_soft_max(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
threadgroup float * buf [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -119,64 +118,70 @@ kernel void kernel_soft_max(
|
|||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
buf[tpitg[0]] = -INFINITY;
|
float lmax = psrc0[tpitg[0]];
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
|
lmax = MAX(lmax, psrc0[i00]);
|
||||||
}
|
}
|
||||||
|
const float max = simd_max(lmax);
|
||||||
// reduce
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg[0] < i) {
|
|
||||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
|
||||||
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
|
||||||
//if (tpitg[0] == 0) {
|
|
||||||
// buf[0] = buf[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
const float max = buf[0];
|
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
buf[tpitg[0]] = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||||
buf[tpitg[0]] += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
// Remember the result of exp here. exp is expensive, so we really do not
|
// Remember the result of exp here. exp is expensive, so we really do not
|
||||||
// whish to compute it twice.
|
// whish to compute it twice.
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce
|
const float sum = simd_sum(lsum);
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
|
||||||
if (tpitg[0] < i) {
|
|
||||||
buf[tpitg[0]] += buf[tpitg[0] + i];
|
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
|
||||||
|
|
||||||
// broadcast - not needed, see above
|
|
||||||
//// broadcast
|
|
||||||
//if (tpitg[0] == 0) {
|
|
||||||
// buf[0] = buf[0];
|
|
||||||
//}
|
|
||||||
|
|
||||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
const float sum = buf[0];
|
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
pdst[i00] /= sum;
|
pdst[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_soft_max_4(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig[2];
|
||||||
|
const int64_t i02 = tgpig[1];
|
||||||
|
const int64_t i01 = tgpig[0];
|
||||||
|
|
||||||
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
|
// parallel max
|
||||||
|
float4 lmax4 = psrc4[tpitg[0]];
|
||||||
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||||
|
}
|
||||||
|
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
|
|
||||||
|
const float max = simd_max(lmax);
|
||||||
|
|
||||||
|
// parallel sum
|
||||||
|
float4 lsum4 = 0.0f;
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||||
|
lsum4 += exp_psrc4;
|
||||||
|
pdst4[i00] = exp_psrc4;
|
||||||
|
}
|
||||||
|
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
|
|
||||||
|
const float sum = simd_sum(lsum);
|
||||||
|
|
||||||
|
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
|
pdst4[i00] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_diag_mask_inf(
|
kernel void kernel_diag_mask_inf(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -195,6 +200,33 @@ kernel void kernel_diag_mask_inf(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_diag_mask_inf_8(
|
||||||
|
device const float4 * src0,
|
||||||
|
device float4 * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int & n_past,
|
||||||
|
uint3 tpig[[thread_position_in_grid]]) {
|
||||||
|
|
||||||
|
const int64_t i = 2*tpig[0];
|
||||||
|
|
||||||
|
dst[i+0] = src0[i+0];
|
||||||
|
dst[i+1] = src0[i+1];
|
||||||
|
int64_t i4 = 4*i;
|
||||||
|
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||||
|
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
||||||
|
const int64_t i00 = i4;
|
||||||
|
for (int k = 3; k >= 0; --k) {
|
||||||
|
if (i00 + 4 + k <= n_past + i01) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
dst[i+1][k] = -INFINITY;
|
||||||
|
if (i00 + k > n_past + i01) {
|
||||||
|
dst[i][k] = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_norm(
|
kernel void kernel_norm(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assumes row size (ne00) is a multiple of 4
|
||||||
|
kernel void kernel_mul_mat_f16_f32_l4(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
const int nrows = ne11;
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t im = tgpig.z;
|
||||||
|
|
||||||
|
device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||||
|
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
|
||||||
|
float sumf = 0;
|
||||||
|
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float all_sum = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_alibi_f32(
|
kernel void kernel_alibi_f32(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
@ -1800,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||||
const half m = il ? ( -8.h * 16.h) : -8.h;
|
const float d2 = d1 / 256.f;
|
||||||
|
const float md = -8.h * xb->d;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||||
const half m = xb->m;
|
const float d2 = d1 / 256.f;
|
||||||
|
const float m = xb->m;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = mask0 << 8;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1858,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d_all = (float)(xb->d);
|
const half d_all = xb->d;
|
||||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||||
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
||||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||||
@ -1871,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|||||||
((il/4)>0 ? 12 : 3);
|
((il/4)>0 ? 12 : 3);
|
||||||
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
||||||
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
||||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
||||||
(scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||||
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
|
||||||
|
const half ml = 4.h * dl;
|
||||||
|
|
||||||
il = (il/2)%4;
|
il = (il/2) & 3;
|
||||||
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||||
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||||
|
dl *= coef;
|
||||||
|
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
||||||
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
||||||
@ -1895,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
||||||
|
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
||||||
|
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint8_t * q = xb->qs;
|
device const uchar * q = xb->qs;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const float d = (float)(xb->d);
|
|
||||||
const float min = (float)(xb->dmin);
|
|
||||||
short is = (il/4) * 2;
|
short is = (il/4) * 2;
|
||||||
q = q + (il/4) * 32 + 16 * (il&1);
|
q = q + (il/4) * 32 + 16 * (il&1);
|
||||||
il = il%4;
|
il = il & 3;
|
||||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
const half min = xb->dmin;
|
||||||
|
const half dl = d * sc[0];
|
||||||
|
const half ml = min * sc[1];
|
||||||
#else
|
#else
|
||||||
q = q + 16 * (il&1);
|
q = q + 16 * (il&1);
|
||||||
device const uint8_t * s = xb->scales;
|
device const uint8_t * s = xb->scales;
|
||||||
device const half2 * dh = (device const half2 *)xb->d;
|
device const half2 * dh = (device const half2 *)xb->d;
|
||||||
const float2 d = (float2)dh[0];
|
const float2 d = (float2)dh[0];
|
||||||
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
||||||
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
|
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
||||||
#endif
|
#endif
|
||||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
@ -1928,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|||||||
device const uint8_t * qh = xb->qh;
|
device const uint8_t * qh = xb->qh;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const float d = (float)(xb->d);
|
|
||||||
const float min = (float)(xb->dmin);
|
|
||||||
short is = (il/4) * 2;
|
short is = (il/4) * 2;
|
||||||
q = q + 32 * (il/4) + 16 * (il&1);
|
q = q + 32 * (il/4) + 16 * (il&1);
|
||||||
qh = qh + 16 * (il&1);
|
qh = qh + 16 * (il&1);
|
||||||
uint8_t ul = 1 << (il/2);
|
uint8_t ul = 1 << (il/2);
|
||||||
il = il%4;
|
il = il & 3;
|
||||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
const half min = xb->dmin;
|
||||||
|
const half dl = d * sc[0];
|
||||||
|
const half ml = min * sc[1];
|
||||||
|
|
||||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||||
const float qh_val = il<2 ? 16.f : 256.f;
|
const half qh_val = il<2 ? 16.h : 256.h;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||||
}
|
}
|
||||||
@ -1959,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d_all = (float)(xb->d);
|
const half d_all = xb->d;
|
||||||
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
||||||
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
||||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||||
@ -1967,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
||||||
qh = qh + 32*(il/8) + 16*(il&1);
|
qh = qh + 32*(il/8) + 16*(il&1);
|
||||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
half sc = scales[(il%2) + 2 * ((il/2))];
|
||||||
il = (il/2)%4;
|
il = (il/2) & 3;
|
||||||
#else
|
#else
|
||||||
ql = ql + 16 * (il&1);
|
ql = ql + 16 * (il&1);
|
||||||
float sc = scales[il];
|
half sc = scales[il];
|
||||||
#endif
|
#endif
|
||||||
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||||
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
||||||
|
const half coef = il>1 ? 1.f/16.h : 1.h;
|
||||||
|
const half ml = d_all * sc * 32.h;
|
||||||
|
const half dl = d_all * sc * coef;
|
||||||
for (int i = 0; i < 16; ++i) {
|
for (int i = 0; i < 16; ++i) {
|
||||||
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
||||||
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
||||||
const float coef = il>1 ? 1.f/16.f : 1.f;
|
reg[i/4][i%4] = dl * q - ml;
|
||||||
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
|
||||||
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
|
||||||
reg[i/4][i%4] = d_all * sc * q * coef;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user