From 249a7902ec710c8d027b9cc0ed10219d2b4184f8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 27 Nov 2024 01:21:59 -0600 Subject: [PATCH] vulkan: further optimize q5_k mul_mat_vec (#10479) --- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 22a6bfae4..b455cbd31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -34,9 +34,6 @@ void main() { const uint q_offset = 32*v_im + l0; const uint y_offset = 64*v_im + l0; - const uint8_t hm1 = uint8_t(1 << (2*v_im)); - const uint8_t hm2 = uint8_t(hm1 << 4); - FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { @@ -71,6 +68,18 @@ void main() { uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; + uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); @@ -102,31 +111,26 @@ void main() { B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16]; B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24]; - uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2]; - uint32_t qh1 = qh0 >> 8; - uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8]; - uint32_t qh17 = qh16 >> 8; - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)), - FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0))))); + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); const FLOAT_TYPE sy = - fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)), - FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0))))); + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); const FLOAT_TYPE sz = - fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)), - FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0))))); + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); const FLOAT_TYPE sw = - fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)), - FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0))))); + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); const FLOAT_TYPE smin = fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,