vulkan: dynamic subgroup size for the remaining k quants (#10745)

* q5_k

q4_k

q3_k

q2_k

q6_k multi row example

* revert as multi row isnt faster for k quants
This commit is contained in:
Eve 2024-12-10 19:33:23 +00:00 committed by GitHub
parent ae4b922614
commit dafae66cc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 72 additions and 61 deletions

View File

@ -44,12 +44,6 @@
#define MAX_VK_BUFFERS 256
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
#define VK_CHECK(err, msg) \
do { \
vk::Result err_ = (err); \
@ -1792,10 +1786,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
@ -1806,10 +1800,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true);
@ -1820,10 +1814,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true);

View File

@ -2,8 +2,6 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#define K_QUANTS_PER_ITERATION 2
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif

View File

@ -3,9 +3,11 @@
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -20,22 +22,25 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
f16vec2 d = data_a[ib0 + i].d;
@ -71,7 +76,7 @@ void main() {
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
@ -96,7 +101,7 @@ void main() {
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}

View File

@ -3,9 +3,11 @@
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -20,17 +22,20 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7
const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im));
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
@ -38,7 +43,7 @@ void main() {
const uint s_shift = 4 * v_im;
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
@ -66,7 +71,7 @@ void main() {
u8vec2 s10 = unpack8(s10_16);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
@ -83,7 +88,7 @@ void main() {
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}

View File

@ -4,11 +4,12 @@
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -22,14 +23,17 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const uint step = 4;
const uint il = tid/step; // 0...3
const uint ir = tid - step*il; // 0...7 or 0...3
const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const uint il = itid/step; // 0...3
const uint ir = itid - step*il; // 0...7 or 0...3
const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2;
@ -40,7 +44,7 @@ void main() {
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
@ -115,7 +119,7 @@ void main() {
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}

View File

@ -4,9 +4,11 @@
#include "mul_mat_vec_base.comp"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -21,11 +23,14 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16
const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;
const uint il = tid/4; // 0...3
const uint ir = tid - 4*il; // 0...7 or 0...3
const uint il = itid/4; // 0...3
const uint ir = itid - 4*il; // 0...7 or 0...3
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2;
@ -36,7 +41,7 @@ void main() {
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
@ -143,7 +148,7 @@ void main() {
// sum up partial sums and write back result
barrier();
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
[[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}