mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 21:10:24 +01:00
parent
e3932593d4
commit
469c9addef
18
ggml-metal.m
18
ggml-metal.m
@ -62,6 +62,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul);
|
GGML_METAL_DECL_KERNEL(mul);
|
||||||
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
||||||
GGML_METAL_DECL_KERNEL(scale);
|
GGML_METAL_DECL_KERNEL(scale);
|
||||||
|
GGML_METAL_DECL_KERNEL(scale_4);
|
||||||
GGML_METAL_DECL_KERNEL(silu);
|
GGML_METAL_DECL_KERNEL(silu);
|
||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
GGML_METAL_DECL_KERNEL(gelu);
|
GGML_METAL_DECL_KERNEL(gelu);
|
||||||
@ -249,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul);
|
GGML_METAL_ADD_KERNEL(mul);
|
||||||
GGML_METAL_ADD_KERNEL(mul_row);
|
GGML_METAL_ADD_KERNEL(mul_row);
|
||||||
GGML_METAL_ADD_KERNEL(scale);
|
GGML_METAL_ADD_KERNEL(scale);
|
||||||
|
GGML_METAL_ADD_KERNEL(scale_4);
|
||||||
GGML_METAL_ADD_KERNEL(silu);
|
GGML_METAL_ADD_KERNEL(silu);
|
||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
GGML_METAL_ADD_KERNEL(gelu);
|
GGML_METAL_ADD_KERNEL(gelu);
|
||||||
@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(mul);
|
GGML_METAL_DEL_KERNEL(mul);
|
||||||
GGML_METAL_DEL_KERNEL(mul_row);
|
GGML_METAL_DEL_KERNEL(mul_row);
|
||||||
GGML_METAL_DEL_KERNEL(scale);
|
GGML_METAL_DEL_KERNEL(scale);
|
||||||
|
GGML_METAL_DEL_KERNEL(scale_4);
|
||||||
GGML_METAL_DEL_KERNEL(silu);
|
GGML_METAL_DEL_KERNEL(silu);
|
||||||
GGML_METAL_DEL_KERNEL(relu);
|
GGML_METAL_DEL_KERNEL(relu);
|
||||||
GGML_METAL_DEL_KERNEL(gelu);
|
GGML_METAL_DEL_KERNEL(gelu);
|
||||||
@ -923,15 +926,20 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
const float scale = *(const float *) src1->data;
|
const float scale = *(const float *) src1->data;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
if (n % 4 == 0) {
|
||||||
|
n /= 4;
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_scale_4];
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||||
|
}
|
||||||
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst);
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
GGML_ASSERT(n % 4 == 0);
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
|
@ -125,9 +125,17 @@ kernel void kernel_mul_row(
|
|||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_scale(
|
kernel void kernel_scale(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
constant float & scale,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = src0[tpig] * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_scale_4(
|
||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device float4 * dst,
|
device float4 * dst,
|
||||||
constant float & scale,
|
constant float & scale,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] * scale;
|
dst[tpig] = src0[tpig] * scale;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user