mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-28 12:57:03 +01:00
metal : trying bs = 512 performance (wip)
This commit is contained in:
parent
e8b00e2941
commit
5a668ea000
12
ggml-metal.m
12
ggml-metal.m
@ -1301,7 +1301,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if (src1t == GGML_TYPE_F32 && ne11 <= 8) {
|
if (src1t == GGML_TYPE_F32) {
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
@ -1340,12 +1340,12 @@ static bool ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
||||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
||||||
|
|
||||||
const int nsg = 8;
|
const int nsg = 4;
|
||||||
|
|
||||||
const int nsg0 = 1;
|
const int nsg0 = 4;
|
||||||
const int nsh0 = 16;
|
const int nsh0 = 4;
|
||||||
const int nsg1 = 1;
|
const int nsg1 = 2;
|
||||||
const int nsh1 = 64;
|
const int nsh1 = 4;
|
||||||
|
|
||||||
GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4
|
GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4
|
||||||
//GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16
|
//GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16
|
||||||
|
@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define NSG0 1
|
#define NSG0 4
|
||||||
#define NSH0 16
|
#define NSH0 4
|
||||||
#define NSG1 1
|
#define NSG1 2
|
||||||
#define NSH1 64
|
#define NSH1 4
|
||||||
|
|
||||||
// each block_q contains 16*nl weights
|
// each block_q contains 16*nl weights
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
||||||
{
|
{
|
||||||
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
||||||
@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
||||||
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
||||||
{
|
{
|
||||||
@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl(
|
|||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
#if 0
|
||||||
#pragma unroll(NSH0)
|
#pragma unroll(NSH0)
|
||||||
for (int k = 0; k < NSH0; ++k) {
|
for (int k = 0; k < NSH0; ++k) {
|
||||||
for (int j = 0; j < NSG0; ++j) {
|
for (int j = 0; j < NSG0; ++j) {
|
||||||
@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
#else
|
||||||
|
#pragma unroll(NSH0)
|
||||||
|
for (int k = 0; k < NSH0; ++k) {
|
||||||
|
for (int i = 0; i < NSG1; ++i) {
|
||||||
|
simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true);
|
||||||
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
for (int j = 0; j < NSG0; ++j) {
|
||||||
|
simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0);
|
||||||
|
for (int i = 0; i < NSG1; ++i) {
|
||||||
|
simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// write the mr to shared memory
|
// write the mr to shared memory
|
||||||
|
@ -2075,7 +2075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#elif 0
|
||||||
for (int r0 = 0; r0 < 32; ++r0) {
|
for (int r0 = 0; r0 < 32; ++r0) {
|
||||||
for (int c0 = 0; c0 < 4096; c0 += 512) {
|
for (int c0 = 0; c0 < 4096; c0 += 512) {
|
||||||
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
|
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
|
||||||
@ -2092,6 +2092,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#elif 1
|
||||||
|
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
|
||||||
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (ggml_type type_a : all_types) {
|
for (ggml_type type_a : all_types) {
|
||||||
|
Loading…
Reference in New Issue
Block a user