mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 08:49:00 +01:00
772703c8ff
Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses the B loads across the rows and also reuses some addressing calculations. This required manually partially unrolling the loop, since the compiler is less willing to unroll outer loops. Add bounds-checking on the last iteration of the loop. I think this was at least partly broken before. Optimize the Q4_K shader to vectorize most loads and reduce the number of bit twiddling instructions.
110 lines
3.4 KiB
Plaintext
110 lines
3.4 KiB
Plaintext
#version 450
|
|
|
|
#ifdef FLOAT16
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
#endif
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
|
|
#extension GL_EXT_null_initializer : enable
|
|
|
|
#include "mul_mat_vec_base.comp"
|
|
|
|
// Workgroup width along x is taken from specialization constant 0,
// i.e. BLOCK_SIZE below; y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Invocations per workgroup (x); also the width of the reduction array.
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
// Number of output rows each workgroup computes (see main()).
layout (constant_id = 1) const uint NUM_ROWS = 1;

// Offsets shared by compute_outputs() (which sets them) and iter()
// (which reads them).
uint a_offset;  // A offset; divided by QUANT_K in compute_outputs()
uint b_offset;  // B offset added to every data_b index
uint d_offset;  // output offset added to every data_d index
uint y_offset;  // distance in B between the two elements of a pair

// Per-row partial sums, one slot per invocation, summed across the workgroup.
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
|
|
// Adds one step of this invocation's partial dot products into temp.
//
// The invocation owns the pair of columns starting at i*BLOCK_SIZE + 2*tid.
// For each of the num_rows rows starting at first_row, the pair is
// dequantized from A and multiplied by the matching pair of B values; the
// products are accumulated into temp[row - first_row].
//
// When lastiter is set, the second element of the pair is checked against
// p.ncols and, if out of range, is neither loaded from B nor accumulated.
// A is still dequantized as a pair. For quantized formats both values sit in
// the same block, so that is harmless. For F16/F32 the second A element is
// still fetched (NOTE(review): could be skipped).
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    const uint col = i*BLOCK_SIZE + 2*tid;
    const uint col_in_block = col % QUANT_K;
    const uint iqs = col_in_block / QUANT_R;   // quant index within the block
    const uint iybs = col - col_in_block;      // y block start index
    const uint b_idx = iybs + iqs;             // B column of the pair's first element

    // The pair's second element lies y_offset columns after the first.
    const bool OOB = lastiter && (b_idx + y_offset >= p.ncols);

    const FLOAT_TYPE b0 = FLOAT_TYPE(data_b[b_offset + b_idx]);
    FLOAT_TYPE b1 = FLOAT_TYPE(0);
    if (!OOB) {
        b1 = FLOAT_TYPE(data_b[b_offset + b_idx + y_offset]);
    }

    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        const uint row = first_row + n;
        // Block index in A holding this row's pair.
        const uint block_idx = (row*p.ncols + col)/QUANT_K;
        const vec2 a = dequantize(block_idx, iqs, a_offset);

        // matrix multiplication
        FLOAT_TYPE acc = fma(FLOAT_TYPE(a.x), b0, temp[n]);
        if (!OOB) {
            acc = fma(FLOAT_TYPE(a.y), b1, acc);
        }
        temp[n] = acc;
    }
}
|
|
|
|
// Computes num_rows consecutive outputs starting at row first_row. The whole
// workgroup cooperates on the result.
//
// Each invocation accumulates strided partial dot products per row into temp
// via iter(). The partials are then summed across the workgroup through the
// shared tmpsh array, and invocation 0 writes the results to data_d.
//
// num_rows must be <= NUM_ROWS, because temp and tmpsh are sized NUM_ROWS.
// It must also be the same for every invocation in the workgroup, because
// barrier() is executed inside the reduction below.
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    // get_offsets() comes from mul_mat_vec_base.comp. The division suggests
    // a_offset is returned in elements and converted here to units of
    // QUANT_K-element blocks (NOTE(review): confirm against the base).
    get_offsets(a_offset, b_offset, d_offset);
    a_offset /= QUANT_K;

    // Distance in B between the two values each iter() call consumes:
    // adjacent for QUANT_R == 1, half a block apart otherwise.
    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

    FLOAT_TYPE temp[NUM_ROWS] = {};

    // iter() calls per pass of the manually unrolled loop. Unsigned so the
    // loop bound and the mask below never mix signed and unsigned operands.
    const uint unroll_count = 8u;

    // i counts in units of BLOCK_SIZE columns and advances by 2 per iter()
    // call, since each invocation handles two columns per call
    // (column = i*BLOCK_SIZE + 2*tid). num_iters is the exclusive bound on i
    // for which that column is still < p.ncols. It is 0 when the invocation
    // has no column at all.
    const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
    // num_iters rounded down to a multiple of 2*unroll_count. Every i visited
    // below it satisfies i + 2 <= num_iters, so both elements of the pair
    // are in bounds and iter() can skip the bounds check.
    const uint unrolled_iters = num_iters & ~(2u*unroll_count - 1u);

    uint i = 0;
    while (i < unrolled_iters) {
        // Manually partially unroll the loop; compilers are reluctant to
        // unroll this outer loop on their own.
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i, false);
            i += 2;
        }
    }
    // Remaining steps, with bounds checking on the pair's second element.
    while (i < num_iters) {
        iter(temp, first_row, num_rows, tid, i, true);
        i += 2;
    }

    // sum up partial sums and write back result
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
        tmpsh[n][tid] = temp[n];
    }
    barrier();
    // Tree reduction: each round folds the upper half of the active range
    // onto the lower half. NOTE(review): this assumes BLOCK_SIZE is a power
    // of two; otherwise some partial sums would never be added.
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                tmpsh[n][tid] += tmpsh[n][tid + s];
            }
        }
        barrier();
    }
    if (tid == 0) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
        }
    }
}
|
|
|
|
// Entry point. Each workgroup computes up to NUM_ROWS consecutive output
// rows. The workgroup index is linearized from the x and z dimensions
// (x + num_groups_x * z). NOTE(review): presumably the host folds large
// dispatches into z to respect per-dimension workgroup-count limits, which
// can launch more workgroups than there are rows.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // p.stride_d is used as the row count here. A workgroup whose first row
    // is already past it has nothing to do. Without this guard,
    // p.stride_d - first_row below would wrap to a huge unsigned value.
    // compute_outputs would then index temp/tmpsh past NUM_ROWS and write
    // beyond the last output row. first_row depends only on the workgroup ID,
    // so the whole workgroup returns together and no invocation is left
    // waiting at a barrier().
    if (first_row >= p.stride_d) {
        return;
    }

    // do NUM_ROWS at a time, unless there aren't enough remaining rows
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
    } else {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}
|