mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 17:51:09 +01:00
metal : warp-based reduction for soft max kernel
This commit is contained in:
parent
68e02c0d58
commit
55717c98c4
10
ggml-metal.m
10
ggml-metal.m
@ -1028,12 +1028,14 @@ void ggml_metal_graph_compute(
|
|||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
if (ne00%4 == 0) {
|
if (ne00%4 == 0) {
|
||||||
|
while (nth < ne00/4 && nth < 256) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||||
} else {
|
} else {
|
||||||
do {
|
while (nth < ne00 && nth < 1024) {
|
||||||
nth *= 2;
|
nth *= 2;
|
||||||
} while (nth <= ne00 && nth <= 1024);
|
}
|
||||||
nth /= 2;
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1046,7 +1048,7 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
139
ggml-metal.metal
139
ggml-metal.metal
@ -39,6 +39,8 @@ typedef struct {
|
|||||||
int8_t qs[QK8_0]; // quants
|
int8_t qs[QK8_0]; // quants
|
||||||
} block_q8_0;
|
} block_q8_0;
|
||||||
|
|
||||||
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||||
|
|
||||||
// general-purpose kernel for addition of two tensors
|
// general-purpose kernel for addition of two tensors
|
||||||
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
||||||
// cons: not very efficient
|
// cons: not very efficient
|
||||||
@ -207,54 +209,55 @@ kernel void kernel_soft_max(
|
|||||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
float max = simd_max(lmax);
|
// find the max value in the block
|
||||||
if (tiisg == 0) {
|
float max_val = simd_max(lmax);
|
||||||
buf[sgitg] = max;
|
if (ntg > N_SIMDWIDTH) {
|
||||||
|
if (sgitg == 0) {
|
||||||
|
buf[tiisg] = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = max_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
max_val = buf[tiisg];
|
||||||
|
max_val = simd_max(max_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
max = buf[0];
|
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float lsum = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||||
lsum += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
// Remember the result of exp here. exp is expensive, so we really do not
|
|
||||||
// wish to compute it twice.
|
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
if (tiisg == 0) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
buf[sgitg] = sum;
|
if (sgitg == 0) {
|
||||||
|
buf[tiisg] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sum = buf[tiisg];
|
||||||
|
sum = simd_sum(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
const float inv_sum = 1.0f/sum;
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
buf[tpitg] += buf[tpitg + i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
sum = buf[0];
|
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
pdst[i00] /= sum;
|
pdst[i00] *= inv_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,53 +291,56 @@ kernel void kernel_soft_max_4(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
float max = simd_max(lmax);
|
|
||||||
if (tiisg == 0) {
|
float max_val = simd_max(lmax);
|
||||||
buf[sgitg] = max;
|
if (ntg > N_SIMDWIDTH) {
|
||||||
|
if (sgitg == 0) {
|
||||||
|
buf[tiisg] = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = max_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
max_val = buf[tiisg];
|
||||||
|
max_val = simd_max(max_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
max = buf[0];
|
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float4 lsum4 = 0.0f;
|
float4 lsum4 = 0.0f;
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
|
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||||
lsum4 += exp_psrc4;
|
lsum4 += exp_psrc4;
|
||||||
pdst4[i00] = exp_psrc4;
|
pdst4[i00] = exp_psrc4;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
if (tiisg == 0) {
|
if (ntg > N_SIMDWIDTH) {
|
||||||
buf[sgitg] = sum;
|
if (sgitg == 0) {
|
||||||
|
buf[tiisg] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sum = buf[tiisg];
|
||||||
|
sum = simd_sum(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
const float inv_sum = 1.0f/sum;
|
||||||
|
|
||||||
// broadcast, simd group number is ntg / 32
|
|
||||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
||||||
if (tpitg < i) {
|
|
||||||
buf[tpitg] += buf[tpitg + i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
sum = buf[0];
|
|
||||||
|
|
||||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
pdst4[i00] /= sum;
|
pdst4[i00] *= inv_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,7 +588,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|||||||
// putting them in the kernel cause a significant performance penalty
|
// putting them in the kernel cause a significant performance penalty
|
||||||
#define N_DST 4 // each SIMD group works on 4 rows
|
#define N_DST 4 // each SIMD group works on 4 rows
|
||||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
|
||||||
//Note: This is a template, but strictly speaking it only applies to
|
//Note: This is a template, but strictly speaking it only applies to
|
||||||
// quantizations where the block size is 32. It also does not
|
// quantizations where the block size is 32. It also does not
|
||||||
// giard against the number of rows not being divisible by
|
// giard against the number of rows not being divisible by
|
||||||
|
Loading…
Reference in New Issue
Block a user