mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 05:42:22 +01:00
metal : add GELU implementation (#1770)
Co-authored-by: Adam Treat <adam@nomic.ai>
This commit is contained in:
parent
245fc3c37d
commit
92f44ff7f7
16
ggml-metal.m
16
ggml-metal.m
@ -45,6 +45,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(scale);
|
GGML_METAL_DECL_KERNEL(scale);
|
||||||
GGML_METAL_DECL_KERNEL(silu);
|
GGML_METAL_DECL_KERNEL(silu);
|
||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
|||||||
GGML_METAL_ADD_KERNEL(scale);
|
GGML_METAL_ADD_KERNEL(scale);
|
||||||
GGML_METAL_ADD_KERNEL(silu);
|
GGML_METAL_ADD_KERNEL(silu);
|
||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
@ -420,6 +422,20 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_GELU:
|
||||||
|
{
|
||||||
|
if (encoder == nil) {
|
||||||
|
encoder = [command_buffer computeCommandEncoder];
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
if (encoder == nil) {
|
||||||
|
@ -81,6 +81,17 @@ kernel void kernel_relu(
|
|||||||
dst[tpig] = max(0.0f, src0[tpig]);
|
dst[tpig] = max(0.0f, src0[tpig]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constant float GELU_COEF_A = 0.044715f;
|
||||||
|
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
|
kernel void kernel_gelu(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
float x = src0[tpig];
|
||||||
|
dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_soft_max(
|
kernel void kernel_soft_max(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user