From 06d4b21598da0162999b35429cfb567ed962d7ec Mon Sep 17 00:00:00 2001
From: Adam Treat <treat.adam@gmail.com>
Date: Mon, 2 Oct 2023 11:30:10 -0400
Subject: [PATCH] Fix offset into the qh and now we have working vulkan
 accelerated for gguff'd llama.

---
 kompute/op_mul_mat_q6_k.comp | 26 ++------------------------
 1 file changed, 2 insertions(+), 24 deletions(-)

diff --git a/kompute/op_mul_mat_q6_k.comp b/kompute/op_mul_mat_q6_k.comp
index 1e4ea37f8..c7b9aa753 100644
--- a/kompute/op_mul_mat_q6_k.comp
+++ b/kompute/op_mul_mat_q6_k.comp
@@ -32,28 +32,13 @@ layout (push_constant) uniform parameter {
     int gqa;
 } pcs;
 
-block_q6_k get_unaligned_block_q6_k(uint index) {
-    block_q6_k fres;
-    [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
-        fres.ql[it] = inA[index + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
-        fres.qh[it] = inA[index + QK_K/2 + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
-        fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
-    }
-    fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
-    return fres;
-}
-
 void main() {
     const uint8_t kmask1 = uint8_t(0x03);
     const uint8_t kmask2 = uint8_t(0x0C);
     const uint8_t kmask3 = uint8_t(0x30);
     const uint8_t kmask4 = uint8_t(0xC0);
 
-    const int nb = pcs.ne00/QK_K;
+    const uint nb = pcs.ne00/QK_K;
 
     const uint r0 = gl_WorkGroupID.x;
     const uint r1 = gl_WorkGroupID.y;
@@ -81,8 +66,6 @@ void main() {
     for (uint i = ix; i < nb; i += 2) {
 
         const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
-//        const uint index = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
-//        const block_q6_k block = get_unaligned_block_q6_k(index);
 
         const uint qlIndex = q_offset_l;
         const uint q2Index = qlIndex + 32;
@@ -91,13 +74,9 @@ void main() {
 
         float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
         for (uint l = 0; l < n; ++l) {
-
-//            const uint8_t currentQ1 = block.ql[qlIndex + l];
-//            const uint8_t currentQ2 = block.ql[q2Index + l];
-//            const uint8_t currentQh = block.qh[qhIndex + l];
             const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
             const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
-            const uint8_t currentQh = inA[baseIndex + qhIndex + l];
+            const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
 
             sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
             sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
@@ -105,7 +84,6 @@ void main() {
             sums[3] += inB[y+l+96] * (int8_t((currentQ2  >> 4) | ((currentQh & kmask4) >> 2)) - 32);
         }
 
-//        sumf += block.d * (sums[0] * block.scales[0+is] + sums[1] * block.scales[2+is] + sums[2] * block.scales[4+is] + sums[3] * block.scales[6+is]);
         float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
         sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
     }