#version 450

#include "common.comp"

// Compile-time parameters for the q4_1 variant of this shader. None of these
// macros is referenced in this file itself; presumably they are consumed by
// op_getrows.comp, which is included at the bottom — TODO confirm.
#define NL 2
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q4_1

// Workgroup size: a single invocation along x (y and z are left unspecified).
layout(local_size_x = 1) in;

// inA is the byte-addressed input that get_unaligned_block_q4_1 reads blocks from.
// inB and out_ are not accessed in this file; presumably they are the row
// indices and the float destination used by op_getrows.comp — TODO confirm.
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };

// Push constants supplied by the host. None of them is read in this file.
// NOTE(review): the names suggest per-buffer offsets (inAOff/inBOff/outOff),
// a row element count (ne00) and strides (nb01/nb1); units and exact meaning
// must be verified against op_getrows.comp and the host-side dispatch code.
layout (push_constant) uniform parameter {
    uint inAOff;
    uint inBOff;
    uint outOff;
    int ne00;
    int nb01;
    int nb1;
} pcs;

// Assembles a block_q4_1 from the byte buffer inA, starting at element
// (byte) position `index`. No alignment of `index` is required by this code:
// every field is fetched through byte-granular reads.
//
// Offsets read relative to `index`:
//   +0 : d  (via u8BufToFloat16)
//   +2 : m  (via u8BufToFloat16)
//   +4 : QK4_1 / 2 consecutive bytes copied into qs[]
block_q4_1 get_unaligned_block_q4_1(uint index) {
    // First byte of the packed quants, right after the d and m fields.
    const uint qs_base = index + 4;

    block_q4_1 blk;
    blk.d = u8BufToFloat16(inA, index);
    blk.m = u8BufToFloat16(inA, index + 2);

    [[unroll]] for (uint i = 0; i < QK4_1 / 2; ++i) {
        blk.qs[i] = inA[qs_base + i];
    }

    return blk;
}

// Loads the q4_1 block located at byte position `index` in inA and returns
// the mat4 produced by dequantize_q4_1 for it. `il` is forwarded unchanged to
// dequantize_q4_1. This function is not called in this file; presumably it is
// the hook that op_getrows.comp (included below) invokes — TODO confirm.
mat4 dequantize_block(uint index, uint il) {
    return dequantize_q4_1(get_unaligned_block_q4_1(index), il);
}

#include "op_getrows.comp"