From bb9dcd560a7e81265398b0d463c40f3e467daf19 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Wed, 14 Feb 2024 20:57:17 +0100 Subject: [PATCH] Refactor validation and enumeration platform checks into functions to clean up ggml_vk_instance_init() --- ggml-vulkan.cpp | 101 ++++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 38 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 37123ac8f..4e5eaff15 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1091,7 +1091,10 @@ static void ggml_vk_print_gpu_info(size_t idx) { } } -static void ggml_vk_instance_init() { +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); + +void ggml_vk_instance_init() { if (vk_instance_initialized) { return; } @@ -1102,54 +1105,40 @@ static void ggml_vk_instance_init() { vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); -#ifdef __APPLE__ - bool portability_enumeration_ext = false; - // Check for portability enumeration extension for MoltenVK support - for (const auto& properties : instance_extensions) { - if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { - portability_enumeration_ext = true; - break; - } - } - if (!portability_enumeration_ext) { - std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; - } -#endif + const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); - std::vector layers = { -#ifdef GGML_VULKAN_VALIDATE - "VK_LAYER_KHRONOS_validation", -#endif - }; - std::vector extensions = { -#ifdef GGML_VULKAN_VALIDATE - "VK_EXT_validation_features", -#endif - }; -#ifdef __APPLE__ + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } if (portability_enumeration_ext) { extensions.push_back("VK_KHR_portability_enumeration"); } -#endif vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); -#ifdef __APPLE__ if (portability_enumeration_ext) { instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; } -#endif + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; -#ifdef GGML_VULKAN_VALIDATE - const std::vector features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; - vk::ValidationFeaturesEXT validation_features = { - features_enable, - {}, - }; - validation_features.setPNext(nullptr); - instance_create_info.setPNext(&validation_features); + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); - std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; -#endif + std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; + } vk_instance.instance = vk::createInstance(instance_create_info); memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES); @@ -5329,6 +5318,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() { return vk_instance.device_indices.size(); } +// Extension availability +static bool ggml_vk_instance_validation_ext_available(const std::vector& instance_extensions) { +#ifdef GGML_VULKAN_VALIDATE + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions) { +#ifdef __APPLE__ + bool portability_enumeration_ext = false; + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + if (!portability_enumeration_ext) { + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl; + } +#endif + return false; + + UNUSED(instance_extensions); +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS