ggml : optimize llamafile cpu matrix multiplication for ppc64le (#10156)

This change upstreams llamafile's cpu matrix
multiplication kernels for ppc64le using MMA
builtins for FP32 datatype.

This change results in a consistent 90%
improvement in input processing time, and 20%
to 80% improvement in output processing time,
across various batch sizes.

The patch is tested with Meta-Lllama-3-8B,
Mistral-7B, Llama-2-7B-chat-hf models on a
IBM POWER10 machine.

Signed-off-by: Amrita H S <amritahs@linux.vnet.ibm.com>
This commit is contained in:
amritahs-ibm 2024-11-09 12:47:50 +05:30 committed by GitHub
parent 8fc393f246
commit e89213492d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 615 additions and 2 deletions

View File

@ -1265,8 +1265,13 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
endif() endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
list(APPEND ARCH_FLAGS -mcpu=powerpc64le) OUTPUT_VARIABLE POWER10_M)
string(FIND ${POWER10_M} "POWER10" substring_index)
if(${substring_index} GREATER_EQUAL 0)
list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
else() else()
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)

View File

@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__MMA__)
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD // VECTORIZED FUSED MULTIPLY ADD
@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX {
}; };
#endif // __AVX__ #endif // __AVX__
//PPC Implementation
#if defined(__MMA__)
#define SAVE_ACC(ACC, ii, jj) \
__builtin_mma_disassemble_acc(vec_C, ACC); \
for (int I = 0; I < 4; I++) { \
for (int J = 0; J < 4; J++) { \
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
} \
} \
template <typename TA, typename TB, typename TC>
class tinyBLAS_PPC {
public:
tinyBLAS_PPC(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
}
private:
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
int64_t i, j;
float *aoffset = NULL, *boffset = NULL;
float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
aoffset = const_cast<float*>(a);
boffset = vec;
j = (rows >> 3);
if (j > 0) {
do {
aoffset1 = aoffset;
aoffset2 = aoffset1 + lda;
aoffset3 = aoffset2 + lda;
aoffset4 = aoffset3 + lda;
aoffset5 = aoffset4 + lda;
aoffset6 = aoffset5 + lda;
aoffset7 = aoffset6 + lda;
aoffset8 = aoffset7 + lda;
aoffset += 8 * lda;
i = (cols >> 3);
if (i > 0) {
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
vector float t1, t2, t3, t4, t5, t6, t7, t8;
do {
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
__builtin_vsx_disassemble_pair(c1, &C1);
__builtin_vsx_disassemble_pair(c2, &C2);
__builtin_vsx_disassemble_pair(c3, &C3);
__builtin_vsx_disassemble_pair(c4, &C4);
__builtin_vsx_disassemble_pair(c5, &C5);
__builtin_vsx_disassemble_pair(c6, &C6);
__builtin_vsx_disassemble_pair(c7, &C7);
__builtin_vsx_disassemble_pair(c8, &C8);
t1 = vec_mergeh(c1[0], c2[0]);
t2 = vec_mergeh(c3[0], c4[0]);
t3 = vec_mergeh(c5[0], c6[0]);
t4 = vec_mergeh(c7[0], c8[0]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset);
vec_xst(t6, 0, boffset+4);
vec_xst(t7, 0, boffset+8);
vec_xst(t8, 0, boffset+12);
t1 = vec_mergel(c1[0], c2[0]);
t2 = vec_mergel(c3[0], c4[0]);
t3 = vec_mergel(c5[0], c6[0]);
t4 = vec_mergel(c7[0], c8[0]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset+16);
vec_xst(t6, 0, boffset+20);
vec_xst(t7, 0, boffset+24);
vec_xst(t8, 0, boffset+28);
t1 = vec_mergeh(c1[1], c2[1]);
t2 = vec_mergeh(c3[1], c4[1]);
t3 = vec_mergeh(c5[1], c6[1]);
t4 = vec_mergeh(c7[1], c8[1]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset+32);
vec_xst(t6, 0, boffset+36);
vec_xst(t7, 0, boffset+40);
vec_xst(t8, 0, boffset+44);
t1 = vec_mergel(c1[1], c2[1]);
t2 = vec_mergel(c3[1], c4[1]);
t3 = vec_mergel(c5[1], c6[1]);
t4 = vec_mergel(c7[1], c8[1]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset+48);
vec_xst(t6, 0, boffset+52);
vec_xst(t7, 0, boffset+56);
vec_xst(t8, 0, boffset+60);
aoffset1 += 8*lda;
aoffset2 += 8*lda;
aoffset3 += 8*lda;
aoffset4 += 8*lda;
boffset += 64;
i--;
} while(i > 0);
}
if (cols & 4) {
vector float c1, c2, c3, c4, c5, c6, c7, c8;
vector float t1, t2, t3, t4, t5, t6, t7, t8;
c1 = vec_xl(0, aoffset1);
c2 = vec_xl(0, aoffset2);
c3 = vec_xl(0, aoffset3);
c4 = vec_xl(0, aoffset4);
c5 = vec_xl(0, aoffset5);
c6 = vec_xl(0, aoffset6);
c7 = vec_xl(0, aoffset7);
c8 = vec_xl(0, aoffset8);
t1 = vec_mergeh(c1, c2);
t2 = vec_mergeh(c3, c4);
t3 = vec_mergeh(c5, c6);
t4 = vec_mergeh(c7, c8);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset);
vec_xst(t6, 0, boffset+4);
vec_xst(t7, 0, boffset+8);
vec_xst(t8, 0, boffset+12);
t1 = vec_mergel(c1, c2);
t2 = vec_mergel(c3, c4);
t3 = vec_mergel(c5, c6);
t4 = vec_mergel(c7, c8);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t3, t4, 0);
t7 = vec_xxpermdi(t1, t2, 3);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset+16);
vec_xst(t6, 0, boffset+20);
vec_xst(t7, 0, boffset+24);
vec_xst(t8, 0, boffset+28);
}
j--;
} while(j > 0);
}
if (rows & 4) {
aoffset1 = aoffset;
aoffset2 = aoffset1 + lda;
aoffset3 = aoffset2 + lda;
aoffset4 = aoffset3 + lda;
aoffset += 4 * lda;
i = (cols >> 3);
if (i > 0) {
__vector_pair C1, C2, C3, C4;
vector float c1[2], c2[2], c3[2], c4[2];
vector float t1, t2, t3, t4, t5, t6, t7, t8;
do {
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
__builtin_vsx_disassemble_pair(c1, &C1);
__builtin_vsx_disassemble_pair(c2, &C2);
__builtin_vsx_disassemble_pair(c3, &C3);
__builtin_vsx_disassemble_pair(c4, &C4);
t1 = vec_mergeh(c1[0], c2[0]);
t2 = vec_mergeh(c3[0], c4[0]);
t3 = vec_mergel(c1[0], c2[0]);
t4 = vec_mergel(c3[0], c4[0]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset);
vec_xst(t6, 0, boffset+4);
vec_xst(t7, 0, boffset+8);
vec_xst(t8, 0, boffset+12);
t1 = vec_mergeh(c1[1], c2[1]);
t2 = vec_mergeh(c3[1], c4[1]);
t3 = vec_mergel(c1[1], c2[1]);
t4 = vec_mergel(c3[1], c4[1]);
t5 = vec_xxpermdi(t1, t2, 0);
t6 = vec_xxpermdi(t1, t2, 3);
t7 = vec_xxpermdi(t3, t4, 0);
t8 = vec_xxpermdi(t3, t4, 3);
vec_xst(t5, 0, boffset+16);
vec_xst(t6, 0, boffset+20);
vec_xst(t7, 0, boffset+24);
vec_xst(t8, 0, boffset+28);
aoffset1 += 8*lda;
aoffset2 += 8*lda;
aoffset3 += 8*lda;
aoffset4 += 8*lda;
boffset += 32;
i--;
} while(i > 0);
}
if (cols & 4) {
vector float c1, c2, c3, c4;
vector float t1, t2, t3, t4;
c1 = vec_xl(0, aoffset1);
c2 = vec_xl(0, aoffset2);
c3 = vec_xl(0, aoffset3);
c4 = vec_xl(0, aoffset4);
t1 = vec_mergeh(c1, c2);
t2 = vec_mergeh(c3, c4);
t3 = vec_xxpermdi(t1, t2, 0);
t4 = vec_xxpermdi(t1, t2, 3);
vec_xst(t3, 0, boffset);
vec_xst(t4, 0, boffset+4);
t1 = vec_mergel(c1, c2);
t2 = vec_mergel(c3, c4);
t3 = vec_xxpermdi(t1, t2, 0);
t4 = vec_xxpermdi(t1, t2, 3);
vec_xst(t3, 0, boffset+8);
vec_xst(t4, 0, boffset+12);
}
}
if (rows & 3) {
aoffset1 = aoffset;
aoffset2 = aoffset1 + lda;
aoffset3 = aoffset2 + lda;
if (cols & 4) {
vector float c1, c2, c3, c4 = {0};
vector float t1, t2, t3, t4;
c1 = vec_xl(0, aoffset1);
c2 = vec_xl(0, aoffset2);
c3 = vec_xl(0, aoffset3);
t1 = vec_mergeh(c1, c2);
t2 = vec_mergeh(c3, c4);
t3 = vec_xxpermdi(t1, t2, 0);
t4 = vec_xxpermdi(t1, t2, 3);
vec_xst(t3, 0, boffset);
vec_xst(t4, 0, boffset+4);
t1 = vec_mergel(c1, c2);
t2 = vec_mergel(c3, c4);
t3 = vec_xxpermdi(t1, t2, 0);
t4 = vec_xxpermdi(t1, t2, 3);
vec_xst(t3, 0, boffset+8);
vec_xst(t4, 0, boffset+12);
}
}
}
void KERNEL_4x4(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[4], vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
for (int l = 0; l < k; l+=4) {
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
}
void KERNEL_4x8(int64_t ii, int64_t jj) {
vec_t vec_A[4], vec_B[8], vec_C[4];
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
}
void KERNEL_8x4(int64_t ii, int64_t jj) {
vec_t vec_A[8], vec_B[4], vec_C[4];
acc_t acc_0, acc_1;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
for (int64_t l = 0; l < k; l+=4) {
READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii+4, jj);
}
void KERNEL_8x8(int64_t ii, int64_t jj) {
vec_t vec_A[16], vec_B[16], vec_C[4];
acc_t acc_0, acc_1, acc_2, acc_3;
__builtin_mma_xxsetaccz(&acc_0);
__builtin_mma_xxsetaccz(&acc_1);
__builtin_mma_xxsetaccz(&acc_2);
__builtin_mma_xxsetaccz(&acc_3);
for (int l = 0; l < k; l+=8) {
READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
for(int x = 0; x < 16; x+=2) {
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
}
}
SAVE_ACC(&acc_0, ii, jj);
SAVE_ACC(&acc_1, ii, jj+4);
SAVE_ACC(&acc_2, ii+4, jj);
SAVE_ACC(&acc_3, ii+4, jj+4);
}
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
int m_rem = MIN(m - m0, 16);
int n_rem = MIN(n - n0, 16);
if (m_rem >= 16 && n_rem >= 8) {
mc = 8;
nc = 8;
gemm<8,8>(m0, m, n0, n);
} else if(m_rem >= 8 && n_rem >= 16) {
mc = 8;
nc = 8;
gemm<8,8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 8) {
mc = 8;
nc = 8;
gemm<8,8>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 8) {
mc = 4;
nc = 8;
gemm<4,8>(m0, m, n0, n);
} else if (m_rem >= 8 && n_rem >= 4) {
mc = 8;
nc = 4;
gemm<8,4>(m0, m, n0, n);
} else if (m_rem >= 4 && n_rem >= 4) {
mc = 4;
nc = 4;
gemm<4,4>(m0, m, n0, n);
} else if ((m_rem < 4) && (n_rem > 4)) {
nc = 4;
switch(m_rem) {
case 1:
mc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 2:
mc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 3:
mc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
default:
return;
}
} else if ((m_rem > 4) && (n_rem < 4)) {
mc = 4;
switch(n_rem) {
case 1:
nc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 2:
nc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 3:
nc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
default:
return;
}
} else {
switch((m_rem << 4) | n_rem) {
case 0x43:
mc = 4;
nc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x42:
mc = 4;
nc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x41:
mc = 4;
nc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x34:
mc = 3;
nc = 4;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x33:
mc = 3;
nc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x32:
mc = 3;
nc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x31:
mc = 3;
nc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x24:
mc = 2;
nc = 4;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x23:
mc = 2;
nc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x22:
mc = 2;
nc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x21:
mc = 2;
nc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x14:
mc = 1;
nc = 4;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x13:
mc = 1;
nc = 3;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x12:
mc = 1;
nc = 2;
gemm_small(m0, m, n0, n, mc, nc);
break;
case 0x11:
mc = 1;
nc = 1;
gemm_small(m0, m, n0, n, mc, nc);
break;
default:
return;
}
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
vec_t vec_C[4];
acc_t acc_0;
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4], vec_B[4];
for (int l=0; l<k; l+=4) {
if (RN >= 4 && RM == 1) {
float* a = const_cast<float*>(A+(ii)*lda+l);
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
} else {
READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
}
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
}
}
}
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (RM == 4 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_4x4;
} else if (RM == 4 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_4x8;
} else if (RM == 8 && RN == 4) {
kernel = &tinyBLAS_PPC::KERNEL_8x4;
} else if (RM == 8 && RN == 8) {
kernel = &tinyBLAS_PPC::KERNEL_8x8;
}
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
(this->*kernel)(ii, jj);
}
}
const TA *const A;
const TB *const B;
TC *C;
TA *At;
TB *Bt;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
#endif
} // namespace } // namespace
/** /**
@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
ith, nth}; ith, nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#elif defined(__MMA__)
if (k % 8)
return false;
tinyBLAS_PPC<float, float, float> tb{
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else #else
return false; return false;
#endif #endif