mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 05:17:21 +01:00
parent
ff8f9a88da
commit
e16b9fa4ba
@ -1001,11 +1001,15 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
const int nth = MIN(32, ne00);
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
if (ne00%4 == 0) {
|
if (ne00%4 == 0) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||||
} else {
|
} else {
|
||||||
|
do {
|
||||||
|
nth *= 2;
|
||||||
|
} while (nth <= ne00 && nth <= 1024);
|
||||||
|
nth /= 2;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
@ -1013,8 +1017,9 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
|
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
|
127
ggml-metal.metal
127
ggml-metal.metal
@ -184,36 +184,73 @@ kernel void kernel_soft_max(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
threadgroup float * buf [[threadgroup(0)]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
const int64_t i03 = tgpig[2];
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
const int64_t i02 = tgpig[1];
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
const int64_t i01 = tgpig[0];
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||||
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||||
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
|
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
||||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
|
||||||
|
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
||||||
lmax = MAX(lmax, psrc0[i00]);
|
lmax = MAX(lmax, psrc0[i00]);
|
||||||
}
|
}
|
||||||
const float max = simd_max(lmax);
|
|
||||||
|
float max = simd_max(lmax);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = max;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// broadcast, simd group number is ntg / 32
|
||||||
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
|
if (tpitg < i) {
|
||||||
|
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
max = buf[0];
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float lsum = 0.0f;
|
float lsum = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||||
lsum += exp_psrc0;
|
lsum += exp_psrc0;
|
||||||
// Remember the result of exp here. exp is expensive, so we really do not
|
// Remember the result of exp here. exp is expensive, so we really do not
|
||||||
// whish to compute it twice.
|
// wish to compute it twice.
|
||||||
pdst[i00] = exp_psrc0;
|
pdst[i00] = exp_psrc0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float sum = simd_sum(lsum);
|
float sum = simd_sum(lsum);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// broadcast, simd group number is ntg / 32
|
||||||
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
|
if (tpitg < i) {
|
||||||
|
buf[tpitg] += buf[tpitg + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sum = buf[0];
|
||||||
|
|
||||||
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
pdst[i00] /= sum;
|
pdst[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
|
|||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne02,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
threadgroup float * buf [[threadgroup(0)]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
const int64_t i03 = tgpig[2];
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
const int64_t i02 = tgpig[1];
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
const int64_t i01 = tgpig[0];
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = (tgpig) / (ne02*ne01);
|
||||||
|
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||||
|
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
|
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
||||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
|
||||||
|
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]);
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||||
}
|
}
|
||||||
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
||||||
|
|
||||||
const float max = simd_max(lmax);
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||||
|
float max = simd_max(lmax);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = max;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// broadcast, simd group number is ntg / 32
|
||||||
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
|
if (tpitg < i) {
|
||||||
|
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
max = buf[0];
|
||||||
|
|
||||||
// parallel sum
|
// parallel sum
|
||||||
float4 lsum4 = 0.0f;
|
float4 lsum4 = 0.0f;
|
||||||
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||||
lsum4 += exp_psrc4;
|
lsum4 += exp_psrc4;
|
||||||
pdst4[i00] = exp_psrc4;
|
pdst4[i00] = exp_psrc4;
|
||||||
}
|
}
|
||||||
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
||||||
|
|
||||||
const float sum = simd_sum(lsum);
|
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||||
|
float sum = simd_sum(lsum);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
buf[sgitg] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// broadcast, simd group number is ntg / 32
|
||||||
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
|
if (tpitg < i) {
|
||||||
|
buf[tpitg] += buf[tpitg + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sum = buf[0];
|
||||||
|
|
||||||
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
pdst4[i00] /= sum;
|
pdst4[i00] /= sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user