From d6407d7f9b8049ca09a444f70bb2e11ccb5539e1 Mon Sep 17 00:00:00 2001 From: Sergey Kosarevsky Date: Sat, 17 Feb 2024 22:03:58 -0800 Subject: [PATCH] Calculate the size of push constants for shader modules --- lvk/vulkan/VulkanClasses.cpp | 87 +++++++++++++++++++++++------------- lvk/vulkan/VulkanClasses.h | 13 ++++-- 2 files changed, 66 insertions(+), 34 deletions(-) diff --git a/lvk/vulkan/VulkanClasses.cpp b/lvk/vulkan/VulkanClasses.cpp index f93cde4b02..188369e578 100644 --- a/lvk/vulkan/VulkanClasses.cpp +++ b/lvk/vulkan/VulkanClasses.cpp @@ -3568,11 +3568,11 @@ VkPipeline lvk::VulkanContext::getVkPipeline(RenderPipelineHandle handle) { } } - const VkShaderModule* vertModule = shaderModulesPool_.get(desc.smVert); - const VkShaderModule* tescModule = shaderModulesPool_.get(desc.smTesc); - const VkShaderModule* teseModule = shaderModulesPool_.get(desc.smTese); - const VkShaderModule* geomModule = shaderModulesPool_.get(desc.smGeom); - const VkShaderModule* fragModule = shaderModulesPool_.get(desc.smFrag); + const lvk::ShaderModuleState* vertModule = shaderModulesPool_.get(desc.smVert); + const lvk::ShaderModuleState* tescModule = shaderModulesPool_.get(desc.smTesc); + const lvk::ShaderModuleState* teseModule = shaderModulesPool_.get(desc.smTese); + const lvk::ShaderModuleState* geomModule = shaderModulesPool_.get(desc.smGeom); + const lvk::ShaderModuleState* fragModule = shaderModulesPool_.get(desc.smFrag); LVK_ASSERT(vertModule); LVK_ASSERT(fragModule); @@ -3622,16 +3622,17 @@ VkPipeline lvk::VulkanContext::getVkPipeline(RenderPipelineHandle handle) { compareOpToVkCompareOp(desc.backFaceStencil.stencilCompareOp)) .stencilMasks(VK_STENCIL_FACE_FRONT_BIT, 0xFF, desc.frontFaceStencil.writeMask, desc.frontFaceStencil.readMask) .stencilMasks(VK_STENCIL_FACE_BACK_BIT, 0xFF, desc.backFaceStencil.writeMask, desc.backFaceStencil.readMask) - .shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_VERTEX_BIT, *vertModule, desc.entryPointVert, &si)) - .shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_FRAGMENT_BIT, *fragModule, desc.entryPointFrag, &si)) + .shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_VERTEX_BIT, vertModule->sm, desc.entryPointVert, &si)) + .shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_FRAGMENT_BIT, fragModule->sm, desc.entryPointFrag, &si)) .shaderStage(tescModule ? lvk::getPipelineShaderStageCreateInfo( - VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, *tescModule, desc.entryPointTesc, &si) + VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, tescModule->sm, desc.entryPointTesc, &si) : VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE}) .shaderStage(teseModule ? lvk::getPipelineShaderStageCreateInfo( - VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, *teseModule, desc.entryPointTese, &si) - : VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE}) - .shaderStage(geomModule ? lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_GEOMETRY_BIT, *geomModule, desc.entryPointGeom, &si) + VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, teseModule->sm, desc.entryPointTese, &si) : VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE}) + .shaderStage(geomModule + ? lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_GEOMETRY_BIT, geomModule->sm, desc.entryPointGeom, &si) + : VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE}) .cullMode(cullModeToVkCullMode(desc.cullMode)) .frontFace(windingModeToVkFrontFace(desc.frontFaceWinding)) .vertexInputState(ciVertexInputState) @@ -3661,7 +3662,7 @@ VkPipeline lvk::VulkanContext::getVkPipeline(ComputePipelineHandle handle) { } if (cps->pipeline_ == VK_NULL_HANDLE) { - const VkShaderModule* sm = shaderModulesPool_.get(cps->desc_.smComp); + const lvk::ShaderModuleState* sm = shaderModulesPool_.get(cps->desc_.smComp); LVK_ASSERT(sm); @@ -3672,7 +3673,7 @@ VkPipeline lvk::VulkanContext::getVkPipeline(ComputePipelineHandle handle) { const VkComputePipelineCreateInfo ci = { .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, .flags = 0, - .stage = lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_COMPUTE_BIT, *sm, cps->desc_.entryPoint, &siComp), + .stage = lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_COMPUTE_BIT, sm->sm, cps->desc_.entryPoint, &siComp), .layout = vkPipelineLayout_, .basePipelineHandle = VK_NULL_HANDLE, .basePipelineIndex = -1, @@ -3765,16 +3766,16 @@ void lvk::VulkanContext::destroy(lvk::RenderPipelineHandle handle) { } void lvk::VulkanContext::destroy(lvk::ShaderModuleHandle handle) { - const VkShaderModule* sm = shaderModulesPool_.get(handle); + const lvk::ShaderModuleState* state = shaderModulesPool_.get(handle); - if (!sm) { + if (!state) { return; } - if (*sm != VK_NULL_HANDLE) { + if (state->sm != VK_NULL_HANDLE) { // a shader module can be destroyed while pipelines created using its shaders are still in use // https://registry.khronos.org/vulkan/specs/1.3/html/chap9.html#vkDestroyShaderModule - vkDestroyShaderModule(getVkDevice(), *sm, nullptr); + vkDestroyShaderModule(getVkDevice(), state->sm, nullptr); } shaderModulesPool_.destroy(handle); @@ -3991,7 +3992,7 @@ lvk::Format lvk::VulkanContext::getFormat(TextureHandle handle) const { lvk::Holder lvk::VulkanContext::createShaderModule(const ShaderModuleDesc& desc, Result* outResult) { Result result; - VkShaderModule sm = desc.dataSize ? createShaderModuleFromSPIRV(desc.data, desc.dataSize, desc.debugName, &result) // binary + ShaderModuleState sm = desc.dataSize ? createShaderModuleFromSPIRV(desc.data, desc.dataSize, desc.debugName, &result) // binary : createShaderModuleFromGLSL(desc.stage, desc.data, desc.debugName, &result); // text if (!result.isOk()) { @@ -4003,10 +4004,10 @@ lvk::Holder lvk::VulkanContext::createShaderModule(cons return {this, shaderModulesPool_.create(std::move(sm))}; } -VkShaderModule lvk::VulkanContext::createShaderModuleFromSPIRV(const void* spirv, - size_t numBytes, - const char* debugName, - Result* outResult) const { +lvk::ShaderModuleState lvk::VulkanContext::createShaderModuleFromSPIRV(const void* spirv, + size_t numBytes, + const char* debugName, + Result* outResult) const { VkShaderModule vkShaderModule = VK_NULL_HANDLE; const VkShaderModuleCreateInfo ci = { @@ -4014,22 +4015,42 @@ VkShaderModule lvk::VulkanContext::createShaderModuleFromSPIRV(const void* spirv .codeSize = numBytes, .pCode = (const uint32_t*)spirv, }; - const VkResult result = vkCreateShaderModule(vkDevice_, &ci, nullptr, &vkShaderModule); - lvk::setResultFrom(outResult, result); + { + const VkResult result = vkCreateShaderModule(vkDevice_, &ci, nullptr, &vkShaderModule); - if (result != VK_SUCCESS) { - return VK_NULL_HANDLE; + lvk::setResultFrom(outResult, result); + + if (result != VK_SUCCESS) { + return {.sm = VK_NULL_HANDLE}; + } } VK_ASSERT(lvk::setDebugObjectName(vkDevice_, VK_OBJECT_TYPE_SHADER_MODULE, (uint64_t)vkShaderModule, debugName)); LVK_ASSERT(vkShaderModule != VK_NULL_HANDLE); - return vkShaderModule; + SpvReflectShaderModule mdl; + SpvReflectResult result = spvReflectCreateShaderModule(numBytes, spirv, &mdl); + LVK_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); + SCOPE_EXIT { + spvReflectDestroyShaderModule(&mdl); + }; + + uint32_t pushConstantsSize = 0; + + for (uint32_t i = 0; i < mdl.push_constant_block_count; ++i) { + const SpvReflectBlockVariable* block = &mdl.push_constant_blocks[i]; + pushConstantsSize = std::max(pushConstantsSize, block->offset + block->size); + } + + return { + .sm = vkShaderModule, + .pushConstantsSize = pushConstantsSize, + }; } -VkShaderModule lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage, +lvk::ShaderModuleState lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const { @@ -4041,7 +4062,7 @@ VkShaderModule lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage, if (!source || !*source) { Result::setResult(outResult, Result::Code::ArgumentOutOfRange, "Shader source is empty"); - return VK_NULL_HANDLE; + return {}; } if (strstr(source, "#version ") == nullptr) { @@ -5286,7 +5307,13 @@ void lvk::VulkanContext::invokeShaderModuleErrorCallback(int line, int col, cons return; } - lvk::ShaderModuleHandle handle = shaderModulesPool_.findObject(&sm); + lvk::ShaderModuleHandle handle; + + for (uint32_t i = 0; i != shaderModulesPool_.objects_.size(); i++) { + if (shaderModulesPool_.objects_[i].obj_.sm == sm) { + handle = shaderModulesPool_.getHandle(i); + } + } if (!handle.empty()) { config_.shaderModuleErrorCallback(this, handle, line, col, debugName); diff --git a/lvk/vulkan/VulkanClasses.h b/lvk/vulkan/VulkanClasses.h index 1120665baf..3b92b4f02d 100644 --- a/lvk/vulkan/VulkanClasses.h +++ b/lvk/vulkan/VulkanClasses.h @@ -341,6 +341,11 @@ struct ComputePipelineState final { VkPipeline pipeline_ = VK_NULL_HANDLE; }; +struct ShaderModuleState final { + VkShaderModule sm = VK_NULL_HANDLE; + uint32_t pushConstantsSize = 0; +}; + class CommandBuffer final : public ICommandBuffer { public: CommandBuffer() = default; @@ -516,7 +521,7 @@ class VulkanContext final : public IContext { ColorSpace getSwapChainColorSpace() const override; uint32_t getNumSwapchainImages() const override; void recreateSwapchain(int newWidth, int newHeight) override; - + uint32_t getFramebufferMSAABitMask() const override; double getTimestampPeriodToMs() const override; @@ -592,8 +597,8 @@ class VulkanContext final : public IContext { void processDeferredTasks() const; void waitDeferredTasks(); lvk::Result growDescriptorPool(uint32_t maxTextures, uint32_t maxSamplers); - VkShaderModule createShaderModuleFromSPIRV(const void* spirv, size_t numBytes, const char* debugName, Result* outResult) const; - VkShaderModule createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const; + ShaderModuleState createShaderModuleFromSPIRV(const void* spirv, size_t numBytes, const char* debugName, Result* outResult) const; + ShaderModuleState createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const; private: friend class lvk::VulkanSwapchain; @@ -651,7 +656,7 @@ class VulkanContext final : public IContext { lvk::ContextConfig config_; - lvk::Pool shaderModulesPool_; + lvk::Pool shaderModulesPool_; lvk::Pool renderPipelinesPool_; lvk::Pool computePipelinesPool_; lvk::Pool samplersPool_;