From 601905e75ee6cbacec0ee5aa523c96fb0258bd63 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 2 Oct 2023 09:00:55 -0400 Subject: [PATCH] Move the subgroups and printf into common. --- kompute/common.comp | 2 ++ kompute/op_mul_mv_q_n.comp | 9 +++------ kompute/op_softmax.comp | 2 -- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/kompute/common.comp b/kompute/common.comp index 12fc7d8b5..2e843a878 100644 --- a/kompute/common.comp +++ b/kompute/common.comp @@ -12,6 +12,8 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int8: require #extension GL_EXT_shader_explicit_arithmetic_types_int16: require #extension GL_EXT_control_flow_attributes: enable +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_EXT_debug_printf : enable #define QK4_0 32 #define QR4_0 2 diff --git a/kompute/op_mul_mv_q_n.comp b/kompute/op_mul_mv_q_n.comp index 83de952dd..15bcbf765 100644 --- a/kompute/op_mul_mv_q_n.comp +++ b/kompute/op_mul_mv_q_n.comp @@ -6,9 +6,6 @@ * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc. */ -#extension GL_KHR_shader_subgroup_arithmetic : require -#extension GL_EXT_debug_printf : enable - void main() { const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); const uint r0 = gl_WorkGroupID.x; @@ -27,9 +24,9 @@ void main() { uint yb = y + ix * BLOCKS_IN_QUANT + il; - debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", - gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, - gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); + //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", + // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, + // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) { for (int row = 0; row < N_ROWS; row++) { diff --git a/kompute/op_softmax.comp b/kompute/op_softmax.comp index 60456a3bb..d21577ac0 100644 --- a/kompute/op_softmax.comp +++ b/kompute/op_softmax.comp @@ -10,8 +10,6 @@ #include "common.comp" -#extension GL_KHR_shader_subgroup_arithmetic : require - layout(local_size_x_id = 0) in; layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };