mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-26 12:21:40 +01:00
Added support for GGML_OP_CLAMP in Metal (#6662)
* Added support for GGML_OP_CLAMP in Metal * Corrected size --------- Co-authored-by: dave-fl <dave@Davids-MacBook-Pro.local>
This commit is contained in:
parent
8800226d65
commit
422c2aff1c
22
ggml-metal.m
22
ggml-metal.m
@ -37,6 +37,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_SCALE,
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
||||||
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
||||||
GGML_METAL_KERNEL_TYPE_TANH,
|
GGML_METAL_KERNEL_TYPE_TANH,
|
||||||
GGML_METAL_KERNEL_TYPE_RELU,
|
GGML_METAL_KERNEL_TYPE_RELU,
|
||||||
GGML_METAL_KERNEL_TYPE_GELU,
|
GGML_METAL_KERNEL_TYPE_GELU,
|
||||||
@ -468,6 +469,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
||||||
@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
|
case GGML_OP_CLAMP:
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
return true;
|
return true;
|
||||||
@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_CLAMP:
|
||||||
|
{
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
||||||
|
|
||||||
|
float min;
|
||||||
|
float max;
|
||||||
|
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
||||||
|
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
||||||
|
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
||||||
|
|
||||||
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
|
@ -213,6 +213,15 @@ kernel void kernel_scale_4(
|
|||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_clamp(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant float & min,
|
||||||
|
constant float & max,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_relu(
|
kernel void kernel_relu(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
Loading…
Reference in New Issue
Block a user