vulkan: copy iq4_nl LUT into shared memory (#10409)

This commit is contained in:
Jeff Bolz 2024-11-20 01:40:18 -06:00 committed by GitHub
parent 1bacb9f625
commit 8fd4b7fa29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 29 additions and 4 deletions

View File

@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() { void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq4nl_shmem();
const uint tid = gl_LocalInvocationID.x % 64; const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32; const uint il = tid/32;
const uint ir = tid%32; const uint ir = tid%32;

View File

@ -12,6 +12,10 @@ void main() {
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
if (i00 >= p.ne00) { if (i00 >= p.ne00) {
return; return;
} }

View File

@ -161,6 +161,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() { void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows // do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) { if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS); compute_outputs(first_row, NUM_ROWS);

View File

@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
#endif #endif
void main() { void main() {
#if defined(DATA_A_IQ4_NL)
init_iq4nl_shmem();
#endif
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z; const uint expert_idx = gl_GlobalInvocationID.z;
#else #else

View File

@ -298,10 +298,21 @@ struct block_iq4_nl_packed16
#define A_TYPE block_iq4_nl #define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16
const int8_t kvalues_iq4nl[16] = { const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
}; };
shared FLOAT_TYPE kvalues_iq4nl[16];
void init_iq4nl_shmem()
{
// copy the table into shared memory and sync
if (gl_LocalInvocationIndex.x < 16) {
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
}
barrier();
}
#endif #endif
#endif // !defined(GGML_TYPES_COMP) #endif // !defined(GGML_TYPES_COMP)

View File

@ -331,11 +331,11 @@ void process_shaders() {
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") { if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else { } else {
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}); string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
} }
string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
} }
} }