mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-15 14:50:51 +01:00
118 lines
4.3 KiB
Plaintext
118 lines
4.3 KiB
Plaintext
|
/**
|
||
|
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
|
||
|
*
|
||
|
* This software is licensed under the terms of the Software for Open Models License (SOM),
|
||
|
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
|
||
|
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
|
||
|
*/
|
||
|
|
||
|
#version 450
|
||
|
|
||
|
#include "common.comp"
|
||
|
|
||
|
#define SIZE_OF_BLOCK sizeof_block_q6_k
|
||
|
|
||
|
// Workgroup size: x and y are taken from specialization constants 0 and 1,
// z is fixed at 1.
layout(local_size_x_id = 0, local_size_y_id = 1) in;
layout(local_size_z = 1) in;

// Storage buffers.
//   inA  - raw bytes of the quantized operand, indexed per byte.
//   inB  - float operand.
//   out_ - float results (write-only).
layout(binding = 0) readonly  buffer tensorInA { uint8_t inA[]; };
layout(binding = 1) readonly  buffer tensorInB { float   inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float   out_[]; };

// Push constants. Member order and types form the interface with the host
// side and must not change.
layout(push_constant) uniform parameter {
    uint inAOff; // added to every byte index into inA
    uint inBOff; // added to every element index into inB
    uint outOff; // added to every element index into out_
    int  ne00;   // main() derives the blocks-per-row count as ne00 / QK_K
    int  ne10;   // multiplied by gl_WorkGroupID.y when indexing inB
    int  ne0;    // used in the inA block offset and the out_ index
    int  ne1;    // used in the inB and out_ offsets for gl_WorkGroupID.z
    int  ne01;   // not read by the code in this file
    int  gqa;    // gl_WorkGroupID.z is divided by this when locating the inA slice
} pcs;
|
||
|
|
||
|
// Assembles a block_q6_k value from the byte buffer inA, one byte at a time,
// starting at byte position `index`.
//
// Sections are read in this order, back to back:
//   ql     : QK_K/2  bytes
//   qh     : QK_K/4  bytes
//   scales : QK_K/16 bytes, each converted to int8_t
//   d      : decoded by u8BufToFloat16 from the bytes that follow the scales
block_q6_k get_unaligned_block_q6_k(uint index) {
    // Start of each section, as absolute byte positions in inA.
    const uint qhStart     = index + QK_K / 2;
    const uint scalesStart = qhStart + QK_K / 4;
    const uint dStart      = scalesStart + QK_K / 16;

    block_q6_k blk;

    [[unroll]] for (uint k = 0; k < QK_K / 2; ++k) {
        blk.ql[k] = inA[index + k];
    }

    [[unroll]] for (uint k = 0; k < QK_K / 4; ++k) {
        blk.qh[k] = inA[qhStart + k];
    }

    [[unroll]] for (uint k = 0; k < QK_K / 16; ++k) {
        blk.scales[k] = int8_t(inA[scalesStart + k]);
    }

    blk.d = u8BufToFloat16(inA, dStart);

    return blk;
}
|
||
|
|
||
|
// Entry point: multiplies Q6_K-quantized rows stored in inA with float vectors
// in inB and writes one float per (row, r1, r2) into out_.
//
// Work split, as written below:
//   - gl_WorkGroupID.x (r0) and gl_SubgroupID select the row:
//     row = 2*r0 + gl_SubgroupID
//     (assumes two subgroups per workgroup -- TODO confirm against the dispatch code).
//   - gl_WorkGroupID.y (r1) and gl_WorkGroupID.z (r2) select the inB vector and
//     the out_ slice.
//   - Inside a subgroup, ix = lane % 2 selects even or odd blocks of the row and
//     tid = lane / 2 selects which 4 consecutive positions of each 32-value
//     group this lane handles (the "0 or 1" note on ip assumes 32 lanes per
//     subgroup -- TODO confirm).
//   - Per-lane partial sums are combined with subgroupAdd; one elected lane
//     stores the total.
//
// Block bytes are read straight from inA at the section offsets also used by
// get_unaligned_block_q6_k: ql at +0, qh at +QK_K/2, scales at +QK_K/2+QK_K/4,
// d at +QK_K/2+QK_K/4+QK_K/16.
void main() {
    // Each qh byte carries the two high bits of four different 6-bit quants;
    // these masks isolate one 2-bit field each.
    const uint8_t kmask1 = uint8_t(0x03);
    const uint8_t kmask2 = uint8_t(0x0C);
    const uint8_t kmask3 = uint8_t(0x30);
    const uint8_t kmask4 = uint8_t(0xC0);

    // Number of quantized blocks per row.
    const int nb = pcs.ne00/QK_K;

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint r2 = gl_WorkGroupID.z;

    const uint row = 2 * r0 + gl_SubgroupID;
    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
    const uint x = row * nb + offset0; // Block index into inA, without the base offset
    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Element index into inB

    float sumf = 0;

    const uint tid = gl_SubgroupInvocationID/2;
    const uint ix  = gl_SubgroupInvocationID%2;
    const uint ip  = tid/8;         // 0 or 1
    const uint il  = tid%8;
    const uint n   = 4;             // positions handled per lane in each group
    const uint l0  = n*il;
    const uint is  = 8*ip + l0/16;  // first scale index used by this lane

    // Per-lane offsets, relative to the start of their respective sections.
    const uint y_offset   = 128*ip + l0; // into the block's slice of inB
    const uint q_offset_l =  64*ip + l0; // into ql
    const uint q_offset_h =  32*ip + l0; // into qh

    for (uint i = ix; i < nb; i += 2) {
        // Byte position of block i of this row inside inA.
        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;

        const uint q1Pos = baseIndex + q_offset_l;
        const uint q2Pos = q1Pos + 32;
        // Fix: qh starts QK_K/2 bytes into the block, after ql. Without this
        // offset the high bits were fetched from the ql section.
        const uint qhPos = baseIndex + QK_K/2 + q_offset_h;
        const uint y     = yy + i * QK_K + y_offset;

        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
        for (uint l = 0; l < n; ++l) {
            const uint8_t currentQ1 = inA[q1Pos + l];
            const uint8_t currentQ2 = inA[q2Pos + l];
            const uint8_t currentQh = inA[qhPos + l];

            // 6-bit quant = 4-bit nibble from ql | 2 bits from qh moved to
            // bits 4..5; subtracting 32 recentres it around zero.
            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
            sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
            sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
        }

        // Apply the four per-group int8 scales and the block-wide scale d.
        const uint scPos = baseIndex + QK_K/2 + QK_K/4 + is;
        const float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
        sumf += d * (sums[0] * int8_t(inA[scPos + 0])
                   + sums[1] * int8_t(inA[scPos + 2])
                   + sums[2] * int8_t(inA[scPos + 4])
                   + sums[3] * int8_t(inA[scPos + 6]));
    }

    // Combine the partial sums of all lanes; a single lane writes the result.
    const float tot = subgroupAdd(sumf);
    if (subgroupElect()) {
        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
    }
}
|