mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-14 06:19:02 +01:00
vulkan: optimize and reenable split_k (#10637)
Use vector loads when possible in mul_mat_split_k_reduce. Use split_k when there aren't enough workgroups to fill the shaders.
This commit is contained in:
parent
91c36c269b
commit
cc98896db8
@ -165,6 +165,7 @@ struct vk_device_struct {
|
|||||||
vk_queue transfer_queue;
|
vk_queue transfer_queue;
|
||||||
bool single_queue;
|
bool single_queue;
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
|
|
||||||
size_t idx;
|
size_t idx;
|
||||||
@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
||||||
@ -1610,11 +1611,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
|
||||||
|
|
||||||
bool maintenance4_support = false;
|
bool maintenance4_support = false;
|
||||||
|
bool sm_builtins = false;
|
||||||
|
|
||||||
// Check if maintenance4 is supported
|
// Check if maintenance4 is supported
|
||||||
for (const auto& properties : ext_props) {
|
for (const auto& properties : ext_props) {
|
||||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||||
maintenance4_support = true;
|
maintenance4_support = true;
|
||||||
|
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
|
||||||
|
sm_builtins = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1622,11 +1626,21 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
vk::PhysicalDeviceMaintenance3Properties props3;
|
vk::PhysicalDeviceMaintenance3Properties props3;
|
||||||
vk::PhysicalDeviceMaintenance4Properties props4;
|
vk::PhysicalDeviceMaintenance4Properties props4;
|
||||||
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
||||||
|
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
|
||||||
props2.pNext = &props3;
|
props2.pNext = &props3;
|
||||||
props3.pNext = &subgroup_props;
|
props3.pNext = &subgroup_props;
|
||||||
|
|
||||||
|
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
|
||||||
|
|
||||||
if (maintenance4_support) {
|
if (maintenance4_support) {
|
||||||
subgroup_props.pNext = &props4;
|
last_struct->pNext = (VkBaseOutStructure *)&props4;
|
||||||
|
last_struct = (VkBaseOutStructure *)&props4;
|
||||||
}
|
}
|
||||||
|
if (sm_builtins) {
|
||||||
|
last_struct->pNext = (VkBaseOutStructure *)&sm_props;
|
||||||
|
last_struct = (VkBaseOutStructure *)&sm_props;
|
||||||
|
}
|
||||||
|
|
||||||
device->physical_device.getProperties2(&props2);
|
device->physical_device.getProperties2(&props2);
|
||||||
device->properties = props2.properties;
|
device->properties = props2.properties;
|
||||||
|
|
||||||
@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
device->vendor_id = device->properties.vendorID;
|
device->vendor_id = device->properties.vendorID;
|
||||||
device->subgroup_size = subgroup_props.subgroupSize;
|
device->subgroup_size = subgroup_props.subgroupSize;
|
||||||
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||||
|
if (sm_builtins) {
|
||||||
|
device->shader_core_count = sm_props.shaderSMCount;
|
||||||
|
} else {
|
||||||
|
device->shader_core_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
bool fp16_storage = false;
|
bool fp16_storage = false;
|
||||||
bool fp16_compute = false;
|
bool fp16_compute = false;
|
||||||
@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
|
|||||||
dst->device->device.resetFences({ dst->device->fence });
|
dst->device->device.resetFences({ dst->device->fence });
|
||||||
}
|
}
|
||||||
|
|
||||||
static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
|
static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
|
||||||
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
|
||||||
// if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
|
|
||||||
// return 4;
|
|
||||||
// }
|
|
||||||
|
|
||||||
return 1;
|
uint32_t split_k = 1;
|
||||||
|
if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
|
||||||
|
// If k is 'large' and the SMs will fill less than halfway, use split_k.
|
||||||
|
uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
|
||||||
|
uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
|
||||||
|
if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
|
||||||
|
split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
|
||||||
|
// Clamp to 2 or 4
|
||||||
|
split_k = std::min(split_k, 4u);
|
||||||
|
if (split_k == 3) {
|
||||||
|
split_k = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
|
return split_k;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
|
||||||
@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||||
|
|
||||||
const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
|
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
|
||||||
|
|
||||||
|
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||||
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
||||||
@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
if (dryrun) {
|
if (dryrun) {
|
||||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||||
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
|
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
|
||||||
if (
|
if (
|
||||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||||
|
@ -5,7 +5,9 @@
|
|||||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {float data_a[];};
|
layout (binding = 0) readonly buffer A {float data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
|
||||||
layout (binding = 1) writeonly buffer D {float data_d[];};
|
layout (binding = 1) writeonly buffer D {float data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ne;
|
uint ne;
|
||||||
@ -13,17 +15,34 @@ layout (push_constant) uniform parameter {
|
|||||||
} p;
|
} p;
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint idx = gl_GlobalInvocationID.x;
|
// Each invocation handles four consecutive components
|
||||||
|
const uint idx = gl_GlobalInvocationID.x * 4;
|
||||||
|
|
||||||
if (idx >= p.ne) {
|
if (idx >= p.ne) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float result = 0.0f;
|
// Check if all four components are in bounds and aligned,
|
||||||
|
// then use vector loads
|
||||||
|
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
|
||||||
|
vec4 result = vec4(0.0f);
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
||||||
result += data_a[i * p.ne + idx];
|
result += data_a4[(i * p.ne + idx) / 4];
|
||||||
|
}
|
||||||
|
|
||||||
|
data_d4[idx / 4] = result;
|
||||||
|
} else {
|
||||||
|
[[unroll]] for (uint j = 0; j < 4; ++j) {
|
||||||
|
if (idx + j < p.ne) {
|
||||||
|
float result = 0.0f;
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
|
||||||
|
result += data_a[i * p.ne + idx + j];
|
||||||
|
}
|
||||||
|
|
||||||
|
data_d[idx + j] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[idx] = result;
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user