Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-15 06:40:45 +01:00)
metal : works with ne00 % 4 == 0 (the ne00 % 16 == 0 assert is commented out; the ne00 % 4 == 0 assert remains)
This commit is contained in:
parent e68e32548f
commit 845876d012
@ -1348,7 +1348,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
const int nsh1 = 64;
|
const int nsh1 = 64;
|
||||||
|
|
||||||
GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4
|
GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4
|
||||||
GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16
|
//GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16
|
||||||
GGML_ASSERT(nsh0 % 2 == 0); // dequantize in chunks of 2x8 = 16
|
GGML_ASSERT(nsh0 % 2 == 0); // dequantize in chunks of 2x8 = 16
|
||||||
GGML_ASSERT(nsh1 % nsh0 == 0);
|
GGML_ASSERT(nsh1 % nsh0 == 0);
|
||||||
GGML_ASSERT(nsh0 >= 2*nsg1); // need enough memory to store the results in f32
|
GGML_ASSERT(nsh0 >= 2*nsg1); // need enough memory to store the results in f32
|
||||||
|
@ -4872,8 +4872,6 @@ void kernel_mul_mm2_impl(
|
|||||||
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||||
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
||||||
{
|
{
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
||||||
|
|
||||||
const size_t offs0 = im*nb12;
|
const size_t offs0 = im*nb12;
|
||||||
@ -4884,14 +4882,20 @@ void kernel_mul_mm2_impl(
|
|||||||
|
|
||||||
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
|
||||||
|
|
||||||
|
//float4 tmp0 = *p1;
|
||||||
|
//tmp0[0] = 1; tmp0[1] = 1; tmp0[2] = 1; tmp0[3] = 1;
|
||||||
|
|
||||||
if (i00 + ic + 4 <= ne00) {
|
if (i00 + ic + 4 <= ne00) {
|
||||||
s14[(8*NSH1*ir + ic)/4] = *p1;
|
s14[(8*NSH1*ir + ic)/4] = *p1;
|
||||||
} else {
|
} else {
|
||||||
for (int k = 0; i00 + ic + k < ne00; k++){
|
s14[(8*NSH1*ir + ic)/4] = 0.0f;
|
||||||
|
for (int k = 0; k < 4; k++){
|
||||||
|
if (i00 + ic + k < ne00) {
|
||||||
s1[8*NSH1*ir + ic + k] = (*p1)[k];
|
s1[8*NSH1*ir + ic + k] = (*p1)[k];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
@ -4918,11 +4922,16 @@ void kernel_mul_mm2_impl(
|
|||||||
|
|
||||||
dequantize_func(p0, il, tmp0);
|
dequantize_func(p0, il, tmp0);
|
||||||
|
|
||||||
|
//for (int z = 0; z < 16; z++) {
|
||||||
|
// tmp0[z/4][z%4] = 1;
|
||||||
|
//}
|
||||||
|
|
||||||
if (icc + 16 <= ne00) {
|
if (icc + 16 <= ne00) {
|
||||||
s016[(8*NSH0*ir + ic)/16] = tmp0;
|
s016[(8*NSH0*ir + ic)/16] = tmp0;
|
||||||
} else {
|
} else {
|
||||||
|
s016[(8*NSH0*ir + ic)/16] = half4x4(0.0h);
|
||||||
for (int k = 0; k < 4; k++){
|
for (int k = 0; k < 4; k++){
|
||||||
if (icc + 4*k <= ne00) {
|
if (icc + 4*k < ne00) {
|
||||||
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
|
||||||
} else {
|
} else {
|
||||||
for (int p = 0; icc + 4*k + p < ne00; p++) {
|
for (int p = 0; icc + 4*k + p < ne00; p++) {
|
||||||
@ -4953,9 +4962,9 @@ void kernel_mul_mm2_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
// write the mr to shared memory
|
// write the mr to shared memory
|
||||||
|
|
||||||
|
@ -2076,16 +2076,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
for (int r0 = 0; r0 < 32; ++r0) {
|
||||||
|
for (int c0 = 0; c0 < 4096; c0 += 512) {
|
||||||
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
|
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 1, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 1, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 2, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 2, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 3, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 3, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 4, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 4, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 5, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 5, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 6, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 6, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 7, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 7, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 8, 4096, { 1, 1}, {1, 1}));
|
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 8, 64 + c0, { 1, 1}, {1, 1}));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user