mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-14 22:38:58 +01:00)

metal : initial working version

parent 099afc6274
commit 92a0c17474

ggml-metal.m (91 lines changed)
@@ -116,6 +116,21 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
@@ -488,6 +503,21 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,  mul_mm_iq2_xxs_f32,  ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,   mul_mm_iq2_xs_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,  mul_mm_iq3_xxs_f32,  ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32,     mul_mm2_f32_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32,     mul_mm2_f16_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32,    mul_mm2_q4_0_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32,    mul_mm2_q4_1_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32,    mul_mm2_q5_0_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32,    mul_mm2_q5_1_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32,    mul_mm2_q8_0_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32,    mul_mm2_q2_K_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32,    mul_mm2_q3_K_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32,    mul_mm2_q4_K_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32,    mul_mm2_q5_K_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32,    mul_mm2_q6_K_f32,    ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32, mul_mm2_iq2_xxs_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32,  mul_mm2_iq2_xs_f32,  ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32, mul_mm2_iq3_xxs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,   mul_mm_id_f32_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,   mul_mm_id_f16_f32,   ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,  mul_mm_id_q4_0_f32,  ctx->support_simdgroup_mm);
@@ -1271,7 +1301,66 @@ static bool ggml_metal_graph_compute(

 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+if (src1t == GGML_TYPE_F32 && ne11 <= 8) {
+    id<MTLComputePipelineState> pipeline = nil;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_F32_F32    ].pipeline; break;
+        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_F16_F32    ].pipeline; break;
+        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_0_F32   ].pipeline; break;
+        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_1_F32   ].pipeline; break;
+        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_0_F32   ].pipeline; break;
+        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_1_F32   ].pipeline; break;
+        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q8_0_F32   ].pipeline; break;
+        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q2_K_F32   ].pipeline; break;
+        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q3_K_F32   ].pipeline; break;
+        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q4_K_F32   ].pipeline; break;
+        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q5_K_F32   ].pipeline; break;
+        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_Q6_K_F32   ].pipeline; break;
+        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XXS_F32].pipeline; break;
+        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ2_XS_F32 ].pipeline; break;
+        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM2_IQ3_XXS_F32].pipeline; break;
+        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+    }
+
+    [encoder setComputePipelineState:pipeline];
+    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
+    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
+    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:13];
+    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:14];
+
+    const int nsg = 8;
+
+    const int nsg0 = 1;
+    const int nsh0 = 8;
+    const int nsg1 = 1;
+    const int nsh1 = 64;
+
+    GGML_ASSERT(ne00 % 4 == 0);    // for zeroing shared memory with half4 / float4
+    GGML_ASSERT(ne00 % 16 == 0);   // dequantize in chunks of 16
+    GGML_ASSERT(nsh0 % 2 == 0);    // dequantize in chunks of 2x8 = 16
+    GGML_ASSERT(nsh1 % nsh0 == 0);
+    GGML_ASSERT(nsh0 >= 2*nsg1);   // need enough memory to store the results in f32
+
+    const size_t shmem = nsg*(8*nsg0)*(8*nsh0)*(sizeof(float)/2) + (8*nsg1)*(8*nsh1)*sizeof(float);
+
+    GGML_ASSERT(shmem <= 32*1024);
+    GGML_ASSERT(shmem >= nsg*(8*nsg0)*(8*nsg1)*sizeof(float));
+
+    [encoder setThreadgroupMemoryLength:shmem atIndex:0];
+    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 8*nsg0*nsg - 1)/(8*nsg0*nsg), (ne11 + 8*nsg1 - 1)/(8*nsg1), ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+} else if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
     !ggml_is_transposed(src0) &&
     !ggml_is_transposed(src1) &&
     src1t == GGML_TYPE_F32 &&
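The two shared-memory asserts in the block above can be sanity-checked with plain arithmetic. A minimal host-side C++ sketch (illustration only, not part of the commit) that plugs in the hard-coded constants nsg = 8, nsg0 = 1, nsh0 = 8, nsg1 = 1, nsh1 = 64 used by the dispatch code:

// Editor's sketch: verify the threadgroup-memory math for the constants above.
#include <cassert>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t nsg = 8, nsg0 = 1, nsh0 = 8, nsg1 = 1, nsh1 = 64;

    // src0 tiles are staged as half (sizeof(float)/2 bytes), src1 tiles as float
    const size_t shmem = nsg*(8*nsg0)*(8*nsh0)*(sizeof(float)/2)
                       + (8*nsg1)*(8*nsh1)*sizeof(float);

    std::printf("shmem = %zu bytes\n", shmem);  // 8192 + 16384 = 24576
    assert(shmem <= 32*1024);                   // fits the 32 KiB threadgroup budget
    assert(shmem >= nsg*(8*nsg0)*(8*nsg1)*sizeof(float)); // room for the f32 result tiles
    return 0;
}

With these defaults the tile staging uses 24576 of the 32768-byte threadgroup-memory budget, and the lower-bound check (2048 bytes for the per-simdgroup f32 result tiles) is comfortably satisfied.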
ggml-metal.metal (369 lines changed)
@@ -4650,25 +4650,28 @@ kernel void kernel_get_rows_i32(

 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const uchar * src0,
-                        device const uchar * src1,
-                        device       float * dst,
-                        constant   int64_t & ne00,
-                        constant   int64_t & ne02,
-                        constant  uint64_t & nb01,
-                        constant  uint64_t & nb02,
-                        constant   int64_t & ne12,
-                        constant  uint64_t & nb10,
-                        constant  uint64_t & nb11,
-                        constant  uint64_t & nb12,
-                        constant   int64_t & ne0,
-                        constant   int64_t & ne1,
-                        constant      uint & r2,
-                        constant      uint & r3,
-                        threadgroup uchar * shared_memory [[threadgroup(0)]],
-                        uint3 tgpig[[threadgroup_position_in_grid]],
-                        uint tiitg[[thread_index_in_threadgroup]],
-                        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+void kernel_mul_mm_impl(
+        device const uchar * src0,
+        device const uchar * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup uchar * shared_memory [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

     threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -4781,6 +4784,194 @@ void kernel_mul_mm_impl(device const uchar * src0,
     }
 }

+#define NSG0 1
+#define NSH0 8
+#define NSG1 1
+#define NSH1 64
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm2_impl(
+        device const uchar * src0,
+        device const uchar * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup uchar * shared_u8 [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const uint nsg = ntg.y; // number of simdgroups
+
+    const int64_t im  = tgpig[2];
+    const int64_t i11 = tgpig[1]*(8*NSG1);
+    const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0);
+
+    const int64_t i12 = im%ne12;
+    const int64_t i13 = im/ne12;
+
+    const int64_t ne01 = ne0;
+    const int64_t ne11 = ne1;
+
+    const int64_t NW = N_SIMDWIDTH;
+
+    const int64_t SH0  = (8*NSG0)*(8*NSH0); // shared memory per threadgroup for src0 data in (half)
+    const int64_t SH04 = SH0/4;             // shared memory per threadgroup for src0 data in (half4)
+
+    const int64_t SH1  = (8*NSG1)*(8*NSH1); // shared memory for src1 data in (float)
+    const int64_t SH14 = SH1/4;             // shared memory for src1 data in (float4)
+
+    const int64_t T1  = 8*NSH1; // row of src1 in shared memory in (float)
+    const int64_t T14 = T1/4;   // row of src1 in shared memory in (float4)
+
+    threadgroup half * shared = (threadgroup half *) shared_u8;
+
+    threadgroup half   * s0  = (threadgroup half   *)(shared + sgitg*SH0);
+    threadgroup half4  * s04 = (threadgroup half4  *)(shared + sgitg*SH0);
+    threadgroup float  * s1  = (threadgroup float  *)(shared + nsg*SH0);
+    threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0);
+
+    threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1));
+
+    simdgroup_half8x8  m0[NSG0];
+    simdgroup_float8x8 m1[NSG1];
+    simdgroup_float8x8 mr[NSG0][NSG1];
+
+    // zero out shared memory SH0 for src0
+    for (int64_t i = tiisg; i < SH04; i += NW) {
+        s04[i] = 0.0h;
+    }
+
+    // zero out shared memory SH1 for src1
+    for (int64_t i = tiitg; i < SH14; i += nsg*NW) {
+        s14[i] = 0.0f;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // initialize mr
+    for (int j = 0; j < NSG0; j++) {
+        for (int i = 0; i < NSG1; i++) {
+            mr[j][i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+        }
+    }
+
+    for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) {
+        // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
+        {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1);
+
+            for (int64_t i = tiitg; i < nload; i += nsg*NW) {
+                const int64_t ic = i%(8*NSH1);
+                const int64_t ir = i/(8*NSH1);
+
+                // TODO: use float4
+                device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
+
+                if (i00 + ic < ne00) {
+                    s1[8*NSH1*ir + ic] = *p1;
+                } else {
+                    s1[8*NSH1*ir + ic] = 0.0f;
+                }
+            }
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+
+        for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
+            // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
+            {
+                const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0);
+
+                half4x4 tmp0;
+
+                for (int64_t i = 16*tiisg; i < nload; i += 16*NW) {
+                    const int64_t ic = i%(8*NSH0);
+                    const int64_t ir = i/(8*NSH0);
+
+                    const int64_t icc = i00 + 8*b0*NSH0 + ic;
+
+                    const int64_t ib = (icc/(16*nl));
+                    const int64_t il = (icc%(16*nl))/16;
+
+                    device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib;
+
+                    dequantize_func(p0, il, tmp0);
+
+                    for (int k = 0; k < 4; k++) {
+                        if (icc + 4*k < ne00) {
+                            s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
+                        } else {
+                            s04[(8*NSH0*ir + ic)/4 + k] = 0.0h;
+                        }
+                    }
+                }
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+#pragma unroll(NSH0)
+            for (int k = 0; k < NSH0; ++k) {
+                for (int j = 0; j < NSG0; ++j) {
+                    simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0);
+                }
+
+                for (int i = 0; i < NSG1; ++i) {
+                    simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true);
+                }
+
+                for (int j = 0; j < NSG0; ++j) {
+                    for (int i = 0; i < NSG1; ++i) {
+                        simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]);
+                    }
+                }
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // write the mr to shared memory
+
+    for (int i = 0; i < NSG1; i++) {
+        for (int j = 0; j < NSG0; j++) {
+            simdgroup_store(mr[j][i], r0 + (8*i)*(8*NSG0) + 8*j, 8*NSG0, 0, true);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    device float * pdst = dst + im*ne1*ne0;
+
+    for (int is = 0; is < NSG1; is++) {
+        const int64_t i1 = i11 + is*8;
+        const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0);
+
+        for (int64_t i = tiisg; i < nstore; i += NW) {
+            const int64_t ic = i%(8*NSG0);
+            const int64_t ir = i/(8*NSG0);
+
+            if (i1 + ir < ne1 && i01 + ic < ne0) {
+                pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];
+            }
+        }
+    }
+}
+
 // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 void kernel_mul_mm_id_impl(
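To make the tile indexing at the top of kernel_mul_mm2_impl concrete, here is a small host-side C++ sketch (illustration only, not part of the commit) that reproduces the tile-origin math and, for a Q4_0-like layout with nl = 2 (so each block_q holds 16*nl = 32 weights), the ib/il block lookup:

// Editor's sketch: reproduce the tile-origin and block-index math of
// kernel_mul_mm2_impl on the host. NSG0/NSG1 match the #defines above.
#include <cstdio>

#define NSG0 1
#define NSG1 1

int main() {
    const int nsg      = 8;            // simdgroups per threadgroup (ntg.y)
    const int tgpig[3] = {3, 2, 5};    // example threadgroup position in grid
    const int sgitg    = 4;            // example simdgroup index
    const int ne12     = 4;            // example src1 batch dim

    // tile origins, as in the kernel
    const int im  = tgpig[2];                               // batch index: 5
    const int i11 = tgpig[1]*(8*NSG1);                      // first src1 row: 16
    const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); // first src0 row: 3*64 + 32 = 224

    const int i12 = im%ne12;                                // 1
    const int i13 = im/ne12;                                // 1

    // quantized block lookup for column icc, with nl = 2 (32 weights per block)
    const int nl  = 2;
    const int icc = 100;
    const int ib  = icc/(16*nl);      // block index: 3
    const int il  = (icc%(16*nl))/16; // 16-weight chunk within the block: 0

    std::printf("i01=%d i11=%d im=%d i12=%d i13=%d ib=%d il=%d\n",
                i01, i11, im, i12, i13, ib, il);
    return 0;
}

Each threadgroup thus owns an (8*NSG0*nsg)-row strip of src0 and an (8*NSG1)-row strip of src1, with each simdgroup accumulating its own (8*NSG0) x (8*NSG1) output tile in mr.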
@@ -4802,7 +4993,9 @@ void kernel_mul_mm_id_impl(
         constant      uint & r3,
         threadgroup uchar * shared_memory,
         uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
         uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {

     threadgroup half * sa = (threadgroup half *)(shared_memory);
@@ -4907,25 +5100,28 @@ void kernel_mul_mm_id_impl(
 }

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
-                          device const uchar * src1,
-                          device       float * dst,
-                          constant   int64_t & ne00,
-                          constant   int64_t & ne02,
-                          constant  uint64_t & nb01,
-                          constant  uint64_t & nb02,
-                          constant   int64_t & ne12,
-                          constant  uint64_t & nb10,
-                          constant  uint64_t & nb11,
-                          constant  uint64_t & nb12,
-                          constant   int64_t & ne0,
-                          constant   int64_t & ne1,
-                          constant      uint & r2,
-                          constant      uint & r3,
-                          threadgroup uchar * shared_memory [[threadgroup(0)]],
-                          uint3 tgpig[[threadgroup_position_in_grid]],
-                          uint tiitg[[thread_index_in_threadgroup]],
-                          uint sgitg[[simdgroup_index_in_threadgroup]]) {
+kernel void kernel_mul_mm(
+        device const uchar * src0,
+        device const uchar * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup uchar * shared_memory [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
         src0,
         src1,
@@ -4944,7 +5140,56 @@ kernel void kernel_mul_mm(device const uchar * src0,
         r3,
         shared_memory,
         tgpig,
+        ntg,
         tiitg,
+        tiisg,
         sgitg);
 }

+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm2(
+        device const uchar * src0,
+        device const uchar * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant      uint & r2,
+        constant      uint & r3,
+        threadgroup uchar * shared_memory [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
+        uint  tiitg[[thread_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mm2_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        shared_memory,
+        tgpig,
+        ntg,
+        tiitg,
+        tiisg,
+        sgitg);
+}
+
@@ -4979,7 +5224,9 @@ kernel void kernel_mul_mm_id(
         device const uchar * src07,
         threadgroup uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 ntg[[threads_per_threadgroup]],
         uint tiitg[[thread_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

@@ -5017,7 +5264,9 @@ kernel void kernel_mul_mm_id(
         r3,
         shared_memory,
         tgpig,
+        ntg,
         tiitg,
+        tiisg,
         sgitg);
 }

@@ -5082,24 +5331,40 @@ typedef void (mat_mm_t)(
     constant uint & r2,
     constant uint & r3,
     threadgroup uchar *,
-    uint3, uint, uint);
+    uint3, uint3, uint, uint, uint);

-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
 template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
 template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;

+template [[host_name("kernel_mul_mm2_f32_f32")]] kernel mat_mm_t kernel_mul_mm2<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm2_f16_f32")]] kernel mat_mm_t kernel_mul_mm2<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm2_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm2_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm2_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm2_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm2_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm2_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm2_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm2_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm2_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm2_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm2<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm2_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm2<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm2_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm2<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm2_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm2<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+
 //
 // indirect matrix-matrix multiplication
 //
@@ -5133,7 +5398,7 @@ typedef void (mat_mm_id_t)(
     device const uchar * src06,
     device const uchar * src07,
     threadgroup uchar *,
-    uint3, uint, uint);
+    uint3, uint3, uint, uint, uint);

 template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;

tests/test-backend-ops.cpp

@@ -480,12 +480,13 @@ struct test_case {

         double err = nmse(f1.data(), f2.data(), f1.size());
         if (err > ud->max_err) {
-            printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
-            //for (int i = 0; i < (int) f1.size(); i++) {
-            //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
-            //}
-            //printf("\n");
-            //exit(1);
+            printf("[%s] NMSE = %.9f > %.9f", ggml_op_desc(t1), err, ud->max_err);
+            printf("\n");
+            for (int i = 0; i < (int) f1.size(); i++) {
+                printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
+            }
+            printf("\n");
+            exit(1);
             ud->ok = false;
         }
         return true;
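For reference, nmse() here is the usual normalized mean squared error between the two backends' outputs. A hedged sketch of what it computes (the real helper is defined elsewhere in this file; treat the exact definition as an assumption):

// Editor's sketch of the NMSE metric used in the check above: squared error
// between the two outputs, normalized by the energy of the reference output.
#include <cstddef>
#include <cstdio>

static double nmse_sketch(const float * a, const float * b, size_t n) {
    double err = 0.0; // sum of squared differences
    double ref = 0.0; // sum of squared reference values
    for (size_t i = 0; i < n; i++) {
        const double d = (double) a[i] - (double) b[i];
        err += d*d;
        ref += (double) a[i] * (double) a[i];
    }
    return err/ref; // 0 when the outputs match exactly
}

int main() {
    const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float b[4] = {1.0f, 2.0f, 3.5f, 4.0f};
    std::printf("nmse = %.9f\n", nmse_sketch(a, b, 4)); // 0.25/30 ~= 0.008333333
    return 0;
}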
@@ -572,9 +573,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 0
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 256;
+        int n_nodes = gf->n_nodes;
+        for (int i = 0; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
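The #else branch fixes the repetition count at 256 and re-appends the whole graph rather than only the output node, so one launch of the perf graph repeats every node of the op's graph. A minimal stand-alone C++ sketch of the duplication pattern (editor's illustration; plain ints stand in for gf->nodes entries):

// Editor's sketch: duplicate a node list n_runs times, as the #else branch
// above does with gf->nodes. Afterwards the list holds n_nodes*(n_runs + 1)
// entries: the original graph plus n_runs copies of it.
#include <cassert>
#include <vector>

int main() {
    std::vector<int> nodes = {0, 1, 2};        // stand-in for gf->nodes[0..n_nodes)
    const int n_runs  = 256;
    const int n_nodes = (int) nodes.size();

    for (int i = 0; i < n_runs; i++) {
        for (int j = 0; j < n_nodes; j++) {
            nodes.push_back(nodes[j]);         // gf->nodes[gf->n_nodes++] = gf->nodes[j];
        }
    }

    assert((int) nodes.size() == n_nodes*(n_runs + 1)); // 3 * 257 = 771
    return 0;
}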
@@ -2044,6 +2055,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }

+#if 0
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
@@ -2063,6 +2075,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
         }
     }
+#else
+    for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 1, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 2, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 3, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 4, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 5, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 6, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 7, 4096, { 1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 8, 4096, { 1, 1}, {1, 1}));
+        }
+    }
+#endif

     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {