mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 06:40:45 +01:00
metal : opts
This commit is contained in:
parent
92a0c17474
commit
e68e32548f
@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute(
|
||||
const int nsg = 8;
|
||||
|
||||
const int nsg0 = 1;
|
||||
const int nsh0 = 8;
|
||||
const int nsh0 = 16;
|
||||
const int nsg1 = 1;
|
||||
const int nsh1 = 64;
|
||||
|
||||
|
110
ggml-metal.metal
110
ggml-metal.metal
@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl(
|
||||
}
|
||||
|
||||
#define NSG0 1
|
||||
#define NSH0 8
|
||||
#define NSH0 16
|
||||
#define NSG1 1
|
||||
#define NSH1 64
|
||||
|
||||
@ -4815,33 +4815,34 @@ void kernel_mul_mm2_impl(
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const uint nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int64_t im = tgpig[2];
|
||||
const int64_t i11 = tgpig[1]*(8*NSG1);
|
||||
const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
|
||||
const int im = tgpig[2];
|
||||
const int i11 = tgpig[1]*(8*NSG1);
|
||||
const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
|
||||
|
||||
const int64_t i12 = im%ne12;
|
||||
const int64_t i13 = im/ne12;
|
||||
const int i12 = im%ne12;
|
||||
const int i13 = im/ne12;
|
||||
|
||||
const int64_t ne01 = ne0;
|
||||
const int64_t ne11 = ne1;
|
||||
const int ne01 = ne0;
|
||||
const int ne11 = ne1;
|
||||
|
||||
const int64_t NW = N_SIMDWIDTH;
|
||||
const int NW = N_SIMDWIDTH;
|
||||
|
||||
const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
|
||||
const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
|
||||
const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half)
|
||||
const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4)
|
||||
|
||||
const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
|
||||
const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4)
|
||||
const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float)
|
||||
const int SH14 = SH1/4; // shread memory for src1 data in (float4)
|
||||
|
||||
const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float)
|
||||
const int64_t T14 = T1/4; // row of src1 in shared memory in (float4)
|
||||
const int T1 = 8*NSH1; // row of src1 in shared memory in (float)
|
||||
const int T14 = T1/4; // row of src1 in shared memory in (float4)
|
||||
|
||||
threadgroup half * shared = (threadgroup half *) shared_u8;
|
||||
|
||||
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
|
||||
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
|
||||
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
|
||||
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
|
||||
threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0);
|
||||
threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0);
|
||||
threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0);
|
||||
threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0);
|
||||
threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
|
||||
|
||||
threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1));
|
||||
|
||||
@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl(
|
||||
simdgroup_float8x8 mr[NSG0][NSG1];
|
||||
|
||||
// zero out shared memory SH0 for src0
|
||||
for (int64_t i = tiisg; i < SH04; i += NW) {
|
||||
for (int i = tiisg; i < SH04; i += NW) {
|
||||
s04[i] = 0.0h;
|
||||
}
|
||||
|
||||
// zero out shared memory SH1 for src1
|
||||
for (int64_t i = tiitg; i < SH14; i += nsg*NW) {
|
||||
for (int i = tiitg; i < SH14; i += nsg*NW) {
|
||||
s14[i] = 0.0f;
|
||||
}
|
||||
|
||||
@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl(
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1);
|
||||
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
||||
|
||||
for (int64_t i = tiitg; i < nload; i += nsg*NW) {
|
||||
const int64_t ic = i%(8*NSH1);
|
||||
const int64_t ir = i/(8*NSH1);
|
||||
const size_t offs0 = im*nb12;
|
||||
|
||||
// TODO: use float4
|
||||
device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
||||
for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) {
|
||||
const int ic = i%(8*NSH1);
|
||||
const int ir = i/(8*NSH1);
|
||||
|
||||
if (i00 + ic < ne00) {
|
||||
s1[8*NSH1*ir + ic] = *p1;
|
||||
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
||||
|
||||
if (i00 + ic + 4 <= ne00) {
|
||||
s14[(8*NSH1*ir + ic)/4] = *p1;
|
||||
} else {
|
||||
s1[8*NSH1*ir + ic] = 0.0f;
|
||||
for (int k = 0; i00 + ic + k < ne00; k++){
|
||||
s1[8*NSH1*ir + ic + k] = (*p1)[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl(
|
||||
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
||||
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
||||
{
|
||||
const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0);
|
||||
const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0);
|
||||
|
||||
half4x4 tmp0;
|
||||
|
||||
for (int64_t i = 16*tiisg; i < nload; i += 16*NW) {
|
||||
const int64_t ic = i%(8*NSH0);
|
||||
const int64_t ir = i/(8*NSH0);
|
||||
const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02;
|
||||
|
||||
const int64_t icc = i00 + 8*b0*NSH0 + ic;
|
||||
for (int i = 16*tiisg; i < nload; i += 16*NW) {
|
||||
const int ic = i%(8*NSH0);
|
||||
const int ir = i/(8*NSH0);
|
||||
|
||||
const int64_t ib = (icc/(16*nl));
|
||||
const int64_t il = (icc%(16*nl))/16;
|
||||
const int icc = i00 + 8*b0*NSH0 + ic;
|
||||
|
||||
device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib;
|
||||
const int ib = (icc/(16*nl));
|
||||
const int il = (icc%(16*nl))/16;
|
||||
|
||||
device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib;
|
||||
|
||||
dequantize_func(p0, il, tmp0);
|
||||
|
||||
for (int k = 0; k < 4; k++){
|
||||
if (icc + 4*k < ne00) {
|
||||
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
||||
} else {
|
||||
s04[(8*NSH0*ir + ic)/4 + k] = 0.0h;
|
||||
if (icc + 16 <= ne00) {
|
||||
s016[(8*NSH0*ir + ic)/16] = tmp0;
|
||||
} else {
|
||||
for (int k = 0; k < 4; k++){
|
||||
if (icc + 4*k <= ne00) {
|
||||
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
||||
} else {
|
||||
for (int p = 0; icc + 4*k + p < ne00; p++) {
|
||||
s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl(
|
||||
device float * pdst = dst + im*ne1*ne0;
|
||||
|
||||
for (int is = 0; is < NSG1; is++) {
|
||||
const int64_t i1 = i11 + is*8;
|
||||
const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0);
|
||||
const int i1 = i11 + is*8;
|
||||
const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0);
|
||||
|
||||
for (int64_t i = tiisg; i < nstore; i += NW) {
|
||||
const int64_t ic = i%(8*NSG0);
|
||||
const int64_t ir = i/(8*NSG0);
|
||||
for (int i = tiisg; i < nstore; i += NW) {
|
||||
const int ic = i%(8*NSG0);
|
||||
const int ir = i/(8*NSG0);
|
||||
|
||||
if (i1 + ir < ne1 && i01 + ic < ne0) {
|
||||
pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];
|
||||
|
Loading…
Reference in New Issue
Block a user