mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 13:28:50 +01:00
metal: somewhat faster f16 x f32 matrix multiply kernel (#2951)
* Somewhat faster f16 x f32 matrix multiply kernel * Better use 32 thread groups for f16 x f32 --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
bce1fef328
commit
e8d9158925
@@ -840,7 +840,7 @@ void ggml_metal_graph_compute(
                 switch (src0t) {
                     case GGML_TYPE_F16:
                         {
-                            nth0 = 64;
+                            nth0 = 32;
                             nth1 = 1;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
                         } break;
|
@@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32(
         device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
         device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

-        sum[tpitg.x] = 0.0f;
+        uint ith = tpitg.x;
+        uint nth = tptg.x;

-        for (int i = tpitg.x; i < ne00; i += tptg.x) {
-            sum[tpitg.x] += (float) x[i] * (float) y[i];
+        sum[ith] = 0.0f;
+
+        for (int i = ith; i < ne00; i += nth) {
+            sum[ith] += (float) x[i] * (float) y[i];
         }

         // accumulate the sum from all threads in the threadgroup
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        for (uint i = tptg.x/2; i > 0; i /= 2) {
-            if (tpitg.x < i) {
-                sum[tpitg.x] += sum[tpitg.x + i];
-            }
-            threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (ith%4 == 0) {
+            for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
         }
-        if (tpitg.x == 0) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (ith%16 == 0) {
+            for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (ith == 0) {
+            for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
             dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
         }
+
+        // Original implementation. Left behind commented out for now
+        //threadgroup_barrier(mem_flags::mem_threadgroup);
+        //for (uint i = tptg.x/2; i > 0; i /= 2) {
+        //    if (tpitg.x < i) {
+        //        sum[tpitg.x] += sum[tpitg.x + i];
+        //    }
+        //    threadgroup_barrier(mem_flags::mem_threadgroup);
+        //}
+        //
+        //if (tpitg.x == 0) {
+        //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+        //}
 }

 kernel void kernel_alibi_f32(
Loading…
Reference in New Issue
Block a user