vulkan: Optimize binary ops (#10270)

Reuse the index calculations across all of src0/src1/dst. Add a shader
variant for when src0/src1 are the same dimensions and additional modulus
for src1 aren't needed. Div/mod are slow, so add "fast" div/mod that
have a fast path when the calculation isn't needed or can be done more
cheaply.
This commit is contained in:
Jeff Bolz 2024-11-13 23:22:55 -06:00 committed by GitHub
parent 66798e42fb
commit af148c9386
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 117 additions and 52 deletions

View File

@ -192,9 +192,10 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16; vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
vk_pipeline pipeline_mul_f32; vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
vk_pipeline pipeline_div_f32; vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_scale_f32;
@ -1456,13 +1457,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@ -3801,20 +3806,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr; return nullptr;
case GGML_OP_ADD: case GGML_OP_ADD:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add_f32; return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
} }
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_add_f16_f32_f16; return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
} }
return nullptr; return nullptr;
case GGML_OP_MUL: case GGML_OP_MUL:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_mul_f32; return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
} }
return nullptr; return nullptr;
case GGML_OP_DIV: case GGML_OP_DIV:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_div_f32; return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
} }
return nullptr; return nullptr;
case GGML_OP_CONCAT: case GGML_OP_CONCAT:

View File

@ -3,6 +3,8 @@
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint idx = gl_GlobalInvocationID.x; const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) { if (idx >= p.ne) {
@ -15,10 +17,13 @@ void main() {
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01; const uint ox = src1_i % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else { } else {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)])); data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
} }
} }

View File

@ -1,14 +1,29 @@
#version 450 #version 450
#extension GL_EXT_shader_16bit_storage : require
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint idx = get_idx(); uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) { if (idx >= p.ne) {
return; continue;
} }
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)])); data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
} }

View File

@ -3,6 +3,8 @@
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3; const int dim = p.param3;

View File

@ -3,12 +3,25 @@
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint idx = get_idx(); uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) { if (idx >= p.ne) {
return; continue;
} }
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)])); data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
} }

View File

@ -1,4 +1,5 @@
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
@ -10,43 +11,50 @@ layout (push_constant) uniform parameter
float param1; float param2; int param3; float param1; float param2; int param3;
} p; } p;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
layout(constant_id = 0) const bool norepeat = false;
uint get_idx() { uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
} }
uint src0_idx(uint idx) { // mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
const uint i03 = idx / (p.ne02*p.ne01*p.ne00); uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}
uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00; const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; i01 = (idx - i03_offset - i02_offset) / p.ne00;
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; i00 = idx - i03_offset - i02_offset - i01*p.ne00;
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
} }
uint src1_idx(uint idx) { uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
const uint i03 = idx / (p.ne02*p.ne01*p.ne00); if (norepeat) {
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); } else {
const uint i02_offset = i02*p.ne01*p.ne00; return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; }
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
} }
uint dst_idx(uint idx) { uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
const uint i23 = idx / (p.ne22*p.ne21*p.ne20); return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
const uint i22_offset = i22*p.ne21*p.ne20;
const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
} }

View File

@ -3,6 +3,8 @@
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint i00 = gl_GlobalInvocationID.x; const uint i00 = gl_GlobalInvocationID.x;
const uint i10 = gl_GlobalInvocationID.y; const uint i10 = gl_GlobalInvocationID.y;

View File

@ -4,6 +4,8 @@
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
#include "dequant_funcs.comp" #include "dequant_funcs.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2; const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y; const uint i10 = gl_GlobalInvocationID.y;

View File

@ -3,12 +3,25 @@
#include "types.comp" #include "types.comp"
#include "generic_binary_head.comp" #include "generic_binary_head.comp"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint idx = get_idx(); uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) { if (idx >= p.ne) {
return; continue;
} }
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)])); data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
} }