metal : warp-based reduction for soft max kernel

This commit is contained in:
Georgi Gerganov 2023-11-30 21:52:32 +02:00
parent 68e02c0d58
commit 55717c98c4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 78 additions and 71 deletions

View File

@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
int nth = 32; // SIMD width int nth = 32; // SIMD width
if (ne00%4 == 0) { if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->pipeline_soft_max_4]; [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else { } else {
do { while (nth < ne00 && nth < 1024) {
nth *= 2; nth *= 2;
} while (nth <= ne00 && nth <= 1024); }
nth /= 2;
[encoder setComputePipelineState:ctx->pipeline_soft_max]; [encoder setComputePipelineState:ctx->pipeline_soft_max];
} }
@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6]; [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;

View File

@ -39,6 +39,8 @@ typedef struct {
int8_t qs[QK8_0]; // quants int8_t qs[QK8_0]; // quants
} block_q8_0; } block_q8_0;
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// general-purpose kernel for addition of two tensors // general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient // cons: not very efficient
@ -207,54 +209,55 @@ kernel void kernel_soft_max(
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
} }
float max = simd_max(lmax); // find the max value in the block
if (tiisg == 0) { float max_val = simd_max(lmax);
buf[sgitg] = max; if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max = buf[0];
// parallel sum // parallel sum
float lsum = 0.0f; float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0; lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
pdst[i00] = exp_psrc0; pdst[i00] = exp_psrc0;
} }
float sum = simd_sum(lsum); float sum = simd_sum(lsum);
if (tiisg == 0) { if (ntg > N_SIMDWIDTH) {
buf[sgitg] = sum; if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
} }
threadgroup_barrier(mem_flags::mem_threadgroup); const float inv_sum = 1.0f/sum;
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[0];
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
pdst[i00] /= sum; pdst[i00] *= inv_sum;
} }
} }
@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
} }
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max = simd_max(lmax);
if (tiisg == 0) { float max_val = simd_max(lmax);
buf[sgitg] = max; if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max = buf[0];
// parallel sum // parallel sum
float4 lsum4 = 0.0f; float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max); const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4; lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4; pdst4[i00] = exp_psrc4;
} }
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
float sum = simd_sum(lsum); float sum = simd_sum(lsum);
if (tiisg == 0) { if (ntg > N_SIMDWIDTH) {
buf[sgitg] = sum; if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
} }
threadgroup_barrier(mem_flags::mem_threadgroup); const float inv_sum = 1.0f/sum;
// broadcast, simd group number is ntg / 32
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
if (tpitg < i) {
buf[tpitg] += buf[tpitg + i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[0];
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
pdst4[i00] /= sum; pdst4[i00] *= inv_sum;
} }
} }
@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
// putting them in the kernel cause a significant performance penalty // putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows #define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
//Note: This is a template, but strictly speaking it only applies to //Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not // quantizations where the block size is 32. It also does not
// giard against the number of rows not being divisible by // giard against the number of rows not being divisible by