mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
metal : assert various kernel requirements
This commit is contained in:
parent
42833bc7a8
commit
0f8df395ce
20
ggml-metal.m
20
ggml-metal.m
@ -774,8 +774,8 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
{
|
{
|
||||||
|
const int64_t nb = ne00;
|
||||||
|
|
||||||
int64_t nb = ne00;
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_concat];
|
[encoder setComputePipelineState:ctx->pipeline_concat];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
@ -807,6 +807,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
||||||
|
|
||||||
const int nth = MIN(1024, ne0);
|
const int nth = MIN(1024, ne0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
@ -904,9 +905,10 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst)/4;
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(n % 4 == 0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
@ -916,9 +918,10 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst)/4;
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(n % 4 == 0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
{
|
{
|
||||||
@ -936,9 +939,10 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst)/4;
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
GGML_ASSERT(n % 4 == 0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
@ -1220,6 +1224,8 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
@ -345,10 +345,11 @@ kernel void kernel_rms_norm(
|
|||||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint ntg[[threads_per_threadgroup]]) {
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||||
device const float * x_scalar = (device const float *) x;
|
device const float * x_scalar = (device const float *) x;
|
||||||
float4 sumf=0;
|
|
||||||
float all_sum=0;
|
float4 sumf = 0;
|
||||||
|
float all_sum = 0;
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
@ -361,6 +362,7 @@ kernel void kernel_rms_norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
// broadcast, simd group number is ntg / 32
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
if (tpitg < i) {
|
if (tpitg < i) {
|
||||||
@ -368,7 +370,9 @@ kernel void kernel_rms_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||||
|
sum[0] += x_scalar[i];
|
||||||
|
}
|
||||||
sum[0] /= ne00;
|
sum[0] /= ne00;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
|
|||||||
y[i00] = x[i00] * scale;
|
y[i00] = x[i00] * scale;
|
||||||
}
|
}
|
||||||
if (tpitg == 0) {
|
if (tpitg == 0) {
|
||||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||||
|
y_scalar[i00] = x_scalar[i00] * scale;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user