Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (#10597)

* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

* Improve performance with better q4_k and q5_k dequant and store unrolling

* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection

* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device

* Vulkan: Implement accumulator switch for specific mul mat mat shaders

* Vulkan: Unroll more loops for more mul mat mat performance

* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic

* Disable coopmat support on AMD proprietary driver

* Remove redundant checks

* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support

* Fix rebase typo

* Fix coopmat2 MUL_MAT_ID pipeline selection
This commit is contained in:
0cc4m 2024-12-07 10:24:15 +01:00 committed by GitHub
parent 86a1934978
commit 3df784b305
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 739 additions and 386 deletions

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,12 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif #endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif #endif
@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
#endif #endif
} p; } p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64; layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64; layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2; layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4; layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2; layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint WARP = 32; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)]; #ifdef COOPMAT
shared FLOAT_TYPE buf_b[BN * (BK+1)]; #define SHMEM_STRIDE (BK + 8)
#else
#define SHMEM_STRIDE (BK + 1)
#endif
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072]; shared u16vec2 row_ids[3072];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif #endif
void main() { void main() {
@ -98,17 +118,32 @@ void main() {
const uint ik = gl_WorkGroupID.x / blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y; const uint ic = gl_WorkGroupID.y;
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER; const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER; const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM); const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@ -156,21 +191,31 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif #endif
float sums[WMITER * TM * WNITER * TN]; #ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN]; FLOAT_TYPE cache_b[WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f; sums[i] = ACC_TYPE(0.0f);
} }
#endif
[[unroll]] for (uint block = start_k; block < end_k; block += BK) { for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
#if defined(DATA_A_F32) || defined(DATA_A_F16) #if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8 #if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@ -181,21 +226,21 @@ void main() {
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else #else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else { } else {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
} }
#endif #endif
#elif defined(DATA_A_Q4_0) #elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -208,7 +253,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q4_1) #elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -222,7 +267,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_0) #elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -237,7 +282,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_1) #elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -253,7 +298,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q8_0) #elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2; const uint iqs = (idx & 0xF) * 2;
@ -265,7 +310,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q2_K) #elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -284,7 +329,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K) #elif defined(DATA_A_Q3_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -308,7 +353,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K) #elif defined(DATA_A_Q4_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -320,15 +365,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -336,7 +386,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K) #elif defined(DATA_A_Q5_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -351,15 +401,20 @@ void main() {
const vec2 loadd = vec2(data_a[ib].d); const vec2 loadd = vec2(data_a[ib].d);
uint8_t sc; const uint scidx0 = (is < 4) ? is : (is + 4);
uint8_t mbyte; const uint scidx1 = (is < 4) ? is : (is - 4);
if (is < 4) { const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
sc = uint8_t(data_a[ib].scales[is ] & 63); const uint scidxshift1 = (is < 4) ? 0 : 2;
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); const uint mbidx0 = is + 4;
} else { const uint mbidx1 = (is < 4) ? is + 4 : is;
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); const uint mbidxshift0 = (is < 4) ? 0 : 4;
} const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc; const float d = loadd.x * sc;
const float m = -loadd.y * mbyte; const float m = -loadd.y * mbyte;
@ -367,7 +422,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K) #elif defined(DATA_A_Q6_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -386,7 +441,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ4_NL) #elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -407,7 +462,7 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@ -423,24 +478,24 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID #elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) { if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
@ -450,16 +505,30 @@ void main() {
pos_a += BK / LOAD_VEC_A; pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B; pos_b += BK / LOAD_VEC_B;
for (uint i = 0; i < BK; i++) { #ifdef COOPMAT
[[unroll]] for (uint i = 0; i < BK; i += TK) {
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
[[unroll]] for (uint i = 0; i < BK; i++) {
// Load from shared into cache // Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) { [[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
} }
} }
@ -468,12 +537,13 @@ void main() {
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
} }
} }
} }
} }
} }
#endif
barrier(); barrier();
} }
@ -485,6 +555,54 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif #endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@ -496,7 +614,7 @@ void main() {
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
#endif #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@ -504,9 +622,10 @@ void main() {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
} }
#endif #endif // MUL_MAT_ID
} }
} }
} }
} }
#endif // COOPMAT
} }

View File

@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
"iq4_nl" "iq4_nl"
}; };
namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32 #ifdef _WIN32
HANDLE stdout_read, stdout_write; HANDLE stdout_read, stdout_write;
@ -198,8 +199,8 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex; static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond; static std::condition_variable compile_count_cond;
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
std::string out_fname = join_paths(output_dir, name + ".spv"); std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname); std::string in_path = join_paths(input_dir, in_fname);
@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
} }
static std::vector<std::future<void>> compiles; static std::vector<std::future<void>> compiles;
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
{ {
// wait until fewer than N compiles are in progress. // wait until fewer than N compiles are in progress.
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
} }
compile_count++; compile_count++;
} }
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
} }
void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
if (coopmat) {
base_dict["COOPMAT"] = "1";
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
// For aligned matmul loads // For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
if (tname != "f16" && tname != "f32") { if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
} }
} }
@ -322,25 +329,24 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul // matmul
for (const auto& fp16 : {false, true}) {
for (const auto& matmul_id : {false, true}) { for (const auto& matmul_id : {false, true}) {
for (const auto& coopmat2 : {false, true}) { // No coopmats
for (const auto& f16acc : {false, true}) { // fp32
#if !defined(VK_NV_cooperative_matrix2) matmul_shaders(false, matmul_id, false, false, false);
if (coopmat2) {
continue; // fp16, fp32acc and fp16acc
} matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
#if defined(VK_NV_cooperative_matrix2)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
#endif #endif
if (coopmat2 && !fp16) {
continue;
}
if (!coopmat2 && f16acc) {
continue;
}
matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
}
}
}
} }
#if defined(VK_NV_cooperative_matrix2) #if defined(VK_NV_cooperative_matrix2)
@ -355,11 +361,11 @@ void process_shaders() {
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
} else { } else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
} }
} }
} }
@ -524,6 +530,7 @@ void write_output_files() {
fclose(hdr); fclose(hdr);
fclose(src); fclose(src);
} }
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
std::map<std::string, std::string> args; std::map<std::string, std::string> args;