mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 07:34:18 +01:00
metal : warp-based reduction for soft max kernel
This commit is contained in:
parent
68e02c0d58
commit
55717c98c4
10
ggml-metal.m
10
ggml-metal.m
@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||
} else {
|
||||
do {
|
||||
while (nth < ne00 && nth < 1024) {
|
||||
nth *= 2;
|
||||
} while (nth <= ne00 && nth <= 1024);
|
||||
nth /= 2;
|
||||
}
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||
}
|
||||
|
||||
@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
139
ggml-metal.metal
139
ggml-metal.metal
@ -39,6 +39,8 @@ typedef struct {
|
||||
int8_t qs[QK8_0]; // quants
|
||||
} block_q8_0;
|
||||
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
|
||||
// general-purpose kernel for addition of two tensors
|
||||
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
||||
// cons: not very efficient
|
||||
@ -207,54 +209,55 @@ kernel void kernel_soft_max(
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
float max = simd_max(lmax);
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max;
|
||||
// find the max value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max = buf[0];
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// wish to compute it twice.
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
buf[tpitg] += buf[tpitg + i];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[0];
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
pdst[i00] /= sum;
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
float max = simd_max(lmax);
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max;
|
||||
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max = buf[0];
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
float sum = simd_sum(lsum);
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
buf[tpitg] += buf[tpitg + i];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[0];
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
pdst4[i00] /= sum;
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
||||
// putting them in the kernel cause a significant performance penalty
|
||||
#define N_DST 4 // each SIMD group works on 4 rows
|
||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
//Note: This is a template, but strictly speaking it only applies to
|
||||
// quantizations where the block size is 32. It also does not
|
||||
// giard against the number of rows not being divisible by
|
||||
|
Loading…
Reference in New Issue
Block a user