vulkan: optimize and reenable split_k (#10637)

Use vector loads when possible in mul_mat_split_k_reduce. Use split_k
when there aren't enough workgroups to fill the shaders.
This commit is contained in:
Jeff Bolz 2024-12-03 13:29:54 -06:00 committed by GitHub
parent 91c36c269b
commit cc98896db8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 17 deletions

View File

@ -165,6 +165,7 @@ struct vk_device_struct {
vk_queue transfer_queue; vk_queue transfer_queue;
bool single_queue; bool single_queue;
uint32_t subgroup_size; uint32_t subgroup_size;
uint32_t shader_core_count;
bool uma; bool uma;
size_t idx; size_t idx;
@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@ -1610,11 +1611,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties(); const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
bool maintenance4_support = false; bool maintenance4_support = false;
bool sm_builtins = false;
// Check if maintenance4 is supported // Check if maintenance4 is supported
for (const auto& properties : ext_props) { for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
maintenance4_support = true; maintenance4_support = true;
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
sm_builtins = true;
} }
} }
@ -1622,11 +1626,21 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceMaintenance3Properties props3; vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceMaintenance4Properties props4; vk::PhysicalDeviceMaintenance4Properties props4;
vk::PhysicalDeviceSubgroupProperties subgroup_props; vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
props2.pNext = &props3; props2.pNext = &props3;
props3.pNext = &subgroup_props; props3.pNext = &subgroup_props;
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
if (maintenance4_support) { if (maintenance4_support) {
subgroup_props.pNext = &props4; last_struct->pNext = (VkBaseOutStructure *)&props4;
last_struct = (VkBaseOutStructure *)&props4;
} }
if (sm_builtins) {
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
last_struct = (VkBaseOutStructure *)&sm_props;
}
device->physical_device.getProperties2(&props2); device->physical_device.getProperties2(&props2);
device->properties = props2.properties; device->properties = props2.properties;
@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->vendor_id = device->properties.vendorID; device->vendor_id = device->properties.vendorID;
device->subgroup_size = subgroup_props.subgroupSize; device->subgroup_size = subgroup_props.subgroupSize;
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount;
} else {
device->shader_core_count = 0;
}
bool fp16_storage = false; bool fp16_storage = false;
bool fp16_compute = false; bool fp16_compute = false;
@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
dst->device->device.resetFences({ dst->device->fence }); dst->device->device.resetFences({ dst->device->fence });
} }
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) { static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
// if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
// return 4;
// }
return 1; uint32_t split_k = 1;
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
// If k is 'large' and the SMs will fill less than halfway, use split_k.
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
// Clamp to 2 or 4
split_k = std::min(split_k, 4u);
if (split_k == 3) {
split_k = 2;
}
}
}
GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k); return split_k;
} }
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (dryrun) { if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13; const uint64_t y_sz_upd = y_sz * ne12 * ne13;
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0; const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
if ( if (
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||

View File

@ -5,7 +5,9 @@
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
layout (binding = 1) writeonly buffer D {float data_d[];}; layout (binding = 1) writeonly buffer D {float data_d[];};
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint ne; uint ne;
@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
} p; } p;
void main() { void main() {
const uint idx = gl_GlobalInvocationID.x; // Each invocation handles four consecutive components
const uint idx = gl_GlobalInvocationID.x * 4;
if (idx >= p.ne) { if (idx >= p.ne) {
return; return;
} }
// Check if all four components are in bounds and aligned,
// then use vector loads
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
vec4 result = vec4(0.0f);
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a4[(i * p.ne + idx) / 4];
}
data_d4[idx / 4] = result;
} else {
[[unroll]] for (uint j = 0; j < 4; ++j) {
if (idx + j < p.ne) {
float result = 0.0f; float result = 0.0f;
[[unroll]] for (uint i = 0; i < p.k_num; i++) { [[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a[i * p.ne + idx]; result += data_a[i * p.ne + idx + j];
} }
data_d[idx] = result; data_d[idx + j] = result;
}
}
}
} }