diff --git a/lvk/LVK.h b/lvk/LVK.h index 0fad354387..cf5120ca46 100644 --- a/lvk/LVK.h +++ b/lvk/LVK.h @@ -928,6 +928,8 @@ namespace lvk { using ShaderModuleErrorCallback = void (*)(lvk::IContext*, lvk::ShaderModuleHandle, int line, int col, const char* debugName); +constexpr uint32_t kMaxCustomExtensions = 32; + struct ContextConfig { bool terminateOnValidationError = false; // invoke std::terminate() on any validation error bool enableValidation = true; @@ -936,6 +938,9 @@ struct ContextConfig { const void* pipelineCacheData = nullptr; size_t pipelineCacheDataSize = 0; ShaderModuleErrorCallback shaderModuleErrorCallback = nullptr; + const char* extensionsInstance[kMaxCustomExtensions] = {}; // add extra instance extensions on top of required ones + const char* extensionsDevice[kMaxCustomExtensions] = {}; // add extra device extensions on top of required ones + void* extensionsDeviceFeatures = nullptr; // inserted into VkPhysicalDeviceVulkan11Features::pNext }; [[nodiscard]] bool isDepthOrStencilFormat(lvk::Format format); diff --git a/lvk/vulkan/VulkanClasses.cpp b/lvk/vulkan/VulkanClasses.cpp index 479a04e75d..0eb72659d8 100644 --- a/lvk/vulkan/VulkanClasses.cpp +++ b/lvk/vulkan/VulkanClasses.cpp @@ -4479,7 +4479,7 @@ bool lvk::VulkanContext::getQueryPoolResults(QueryPoolHandle pool, void lvk::VulkanContext::createInstance() { vkInstance_ = VK_NULL_HANDLE; - const char* instanceExtensionNames[] = { + std::vector instanceExtensionNames = { VK_KHR_SURFACE_EXTENSION_NAME, VK_EXT_DEBUG_UTILS_EXTENSION_NAME, #if defined(_WIN32) @@ -4499,11 +4499,17 @@ void lvk::VulkanContext::createInstance() { #if defined(LVK_WITH_VULKAN_PORTABILITY) VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME, #endif - VK_EXT_VALIDATION_FEATURES_EXTENSION_NAME // enabled only for validation }; - const uint32_t numInstanceExtensions = config_.enableValidation ? (uint32_t)LVK_ARRAY_NUM_ELEMENTS(instanceExtensionNames) - : (uint32_t)LVK_ARRAY_NUM_ELEMENTS(instanceExtensionNames) - 1; + if (config_.enableValidation) { + instanceExtensionNames.push_back(VK_EXT_VALIDATION_FEATURES_EXTENSION_NAME); // enabled only for validation + } + + for (const char* ext : config_.extensionsInstance) { + if (ext) { + instanceExtensionNames.push_back(ext); + } + } #if !defined(ANDROID) // GPU Assisted Validation doesn't work on Android. @@ -4574,8 +4580,8 @@ void lvk::VulkanContext::createInstance() { .pApplicationInfo = &appInfo, .enabledLayerCount = config_.enableValidation ? (uint32_t)LVK_ARRAY_NUM_ELEMENTS(kDefaultValidationLayers) : 0u, .ppEnabledLayerNames = config_.enableValidation ? kDefaultValidationLayers : nullptr, - .enabledExtensionCount = numInstanceExtensions, - .ppEnabledExtensionNames = instanceExtensionNames, + .enabledExtensionCount = (uint32_t)instanceExtensionNames.size(), + .ppEnabledExtensionNames = instanceExtensionNames.data(), }; VK_ASSERT(vkCreateInstance(&ci, nullptr, &vkInstance_)); @@ -4766,7 +4772,7 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) { }; const uint32_t numQueues = ciQueue[0].queueFamilyIndex == ciQueue[1].queueFamilyIndex ? 1 : 2; - const char* deviceExtensionNames[] = { + std::vector deviceExtensionNames = { VK_KHR_SWAPCHAIN_EXTENSION_NAME, #if defined(LVK_WITH_TRACY) VK_EXT_CALIBRATED_TIMESTAMPS_EXTENSION_NAME, @@ -4796,6 +4802,12 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) { #endif }; + for (const char* ext : config_.extensionsDevice) { + if (ext) { + deviceExtensionNames.push_back(ext); + } + } + VkPhysicalDeviceFeatures deviceFeatures10 = { #if !defined(__APPLE__) .geometryShader = VK_TRUE, @@ -4816,6 +4828,7 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) { }; VkPhysicalDeviceVulkan11Features deviceFeatures11 = { .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES, + .pNext = config_.extensionsDeviceFeatures, .storageBuffer16BitAccess = VK_TRUE, .samplerYcbcrConversion = VK_TRUE, .shaderDrawParameters = VK_TRUE, @@ -4876,8 +4889,8 @@ lvk::Result lvk::VulkanContext::initContext(const HWDeviceDesc& desc) { .pNext = createInfoNext, .queueCreateInfoCount = numQueues, .pQueueCreateInfos = ciQueue, - .enabledExtensionCount = (uint32_t)LVK_ARRAY_NUM_ELEMENTS(deviceExtensionNames), - .ppEnabledExtensionNames = deviceExtensionNames, + .enabledExtensionCount = (uint32_t)deviceExtensionNames.size(), + .ppEnabledExtensionNames = deviceExtensionNames.data(), .pEnabledFeatures = &deviceFeatures10, };