Skip to content

Commit

Permalink
Calculate the size of push constants for shader modules
Browse files Browse the repository at this point in the history
  • Loading branch information
corporateshark committed Feb 18, 2024
1 parent 020e3b1 commit d6407d7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 34 deletions.
87 changes: 57 additions & 30 deletions lvk/vulkan/VulkanClasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3568,11 +3568,11 @@ VkPipeline lvk::VulkanContext::getVkPipeline(RenderPipelineHandle handle) {
}
}

const VkShaderModule* vertModule = shaderModulesPool_.get(desc.smVert);
const VkShaderModule* tescModule = shaderModulesPool_.get(desc.smTesc);
const VkShaderModule* teseModule = shaderModulesPool_.get(desc.smTese);
const VkShaderModule* geomModule = shaderModulesPool_.get(desc.smGeom);
const VkShaderModule* fragModule = shaderModulesPool_.get(desc.smFrag);
const lvk::ShaderModuleState* vertModule = shaderModulesPool_.get(desc.smVert);
const lvk::ShaderModuleState* tescModule = shaderModulesPool_.get(desc.smTesc);
const lvk::ShaderModuleState* teseModule = shaderModulesPool_.get(desc.smTese);
const lvk::ShaderModuleState* geomModule = shaderModulesPool_.get(desc.smGeom);
const lvk::ShaderModuleState* fragModule = shaderModulesPool_.get(desc.smFrag);

LVK_ASSERT(vertModule);
LVK_ASSERT(fragModule);
Expand Down Expand Up @@ -3622,16 +3622,17 @@ VkPipeline lvk::VulkanContext::getVkPipeline(RenderPipelineHandle handle) {
compareOpToVkCompareOp(desc.backFaceStencil.stencilCompareOp))
.stencilMasks(VK_STENCIL_FACE_FRONT_BIT, 0xFF, desc.frontFaceStencil.writeMask, desc.frontFaceStencil.readMask)
.stencilMasks(VK_STENCIL_FACE_BACK_BIT, 0xFF, desc.backFaceStencil.writeMask, desc.backFaceStencil.readMask)
.shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_VERTEX_BIT, *vertModule, desc.entryPointVert, &si))
.shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_FRAGMENT_BIT, *fragModule, desc.entryPointFrag, &si))
.shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_VERTEX_BIT, vertModule->sm, desc.entryPointVert, &si))
.shaderStage(lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_FRAGMENT_BIT, fragModule->sm, desc.entryPointFrag, &si))
.shaderStage(tescModule ? lvk::getPipelineShaderStageCreateInfo(
VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, *tescModule, desc.entryPointTesc, &si)
VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, tescModule->sm, desc.entryPointTesc, &si)
: VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE})
.shaderStage(teseModule ? lvk::getPipelineShaderStageCreateInfo(
VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, *teseModule, desc.entryPointTese, &si)
: VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE})
.shaderStage(geomModule ? lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_GEOMETRY_BIT, *geomModule, desc.entryPointGeom, &si)
VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, teseModule->sm, desc.entryPointTese, &si)
: VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE})
.shaderStage(geomModule
? lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_GEOMETRY_BIT, geomModule->sm, desc.entryPointGeom, &si)
: VkPipelineShaderStageCreateInfo{.module = VK_NULL_HANDLE})
.cullMode(cullModeToVkCullMode(desc.cullMode))
.frontFace(windingModeToVkFrontFace(desc.frontFaceWinding))
.vertexInputState(ciVertexInputState)
Expand Down Expand Up @@ -3661,7 +3662,7 @@ VkPipeline lvk::VulkanContext::getVkPipeline(ComputePipelineHandle handle) {
}

if (cps->pipeline_ == VK_NULL_HANDLE) {
const VkShaderModule* sm = shaderModulesPool_.get(cps->desc_.smComp);
const lvk::ShaderModuleState* sm = shaderModulesPool_.get(cps->desc_.smComp);

LVK_ASSERT(sm);

Expand All @@ -3672,7 +3673,7 @@ VkPipeline lvk::VulkanContext::getVkPipeline(ComputePipelineHandle handle) {
const VkComputePipelineCreateInfo ci = {
.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
.flags = 0,
.stage = lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_COMPUTE_BIT, *sm, cps->desc_.entryPoint, &siComp),
.stage = lvk::getPipelineShaderStageCreateInfo(VK_SHADER_STAGE_COMPUTE_BIT, sm->sm, cps->desc_.entryPoint, &siComp),
.layout = vkPipelineLayout_,
.basePipelineHandle = VK_NULL_HANDLE,
.basePipelineIndex = -1,
Expand Down Expand Up @@ -3765,16 +3766,16 @@ void lvk::VulkanContext::destroy(lvk::RenderPipelineHandle handle) {
}

void lvk::VulkanContext::destroy(lvk::ShaderModuleHandle handle) {
const VkShaderModule* sm = shaderModulesPool_.get(handle);
const lvk::ShaderModuleState* state = shaderModulesPool_.get(handle);

if (!sm) {
if (!state) {
return;
}

if (*sm != VK_NULL_HANDLE) {
if (state->sm != VK_NULL_HANDLE) {
// a shader module can be destroyed while pipelines created using its shaders are still in use
// https://registry.khronos.org/vulkan/specs/1.3/html/chap9.html#vkDestroyShaderModule
vkDestroyShaderModule(getVkDevice(), *sm, nullptr);
vkDestroyShaderModule(getVkDevice(), state->sm, nullptr);
}

shaderModulesPool_.destroy(handle);
Expand Down Expand Up @@ -3991,7 +3992,7 @@ lvk::Format lvk::VulkanContext::getFormat(TextureHandle handle) const {

lvk::Holder<lvk::ShaderModuleHandle> lvk::VulkanContext::createShaderModule(const ShaderModuleDesc& desc, Result* outResult) {
Result result;
VkShaderModule sm = desc.dataSize ? createShaderModuleFromSPIRV(desc.data, desc.dataSize, desc.debugName, &result) // binary
ShaderModuleState sm = desc.dataSize ? createShaderModuleFromSPIRV(desc.data, desc.dataSize, desc.debugName, &result) // binary
: createShaderModuleFromGLSL(desc.stage, desc.data, desc.debugName, &result); // text

if (!result.isOk()) {
Expand All @@ -4003,33 +4004,53 @@ lvk::Holder<lvk::ShaderModuleHandle> lvk::VulkanContext::createShaderModule(cons
return {this, shaderModulesPool_.create(std::move(sm))};
}

VkShaderModule lvk::VulkanContext::createShaderModuleFromSPIRV(const void* spirv,
size_t numBytes,
const char* debugName,
Result* outResult) const {
lvk::ShaderModuleState lvk::VulkanContext::createShaderModuleFromSPIRV(const void* spirv,
size_t numBytes,
const char* debugName,
Result* outResult) const {
VkShaderModule vkShaderModule = VK_NULL_HANDLE;

const VkShaderModuleCreateInfo ci = {
.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
.codeSize = numBytes,
.pCode = (const uint32_t*)spirv,
};
const VkResult result = vkCreateShaderModule(vkDevice_, &ci, nullptr, &vkShaderModule);

lvk::setResultFrom(outResult, result);
{
const VkResult result = vkCreateShaderModule(vkDevice_, &ci, nullptr, &vkShaderModule);

if (result != VK_SUCCESS) {
return VK_NULL_HANDLE;
lvk::setResultFrom(outResult, result);

if (result != VK_SUCCESS) {
return {.sm = VK_NULL_HANDLE};
}
}

VK_ASSERT(lvk::setDebugObjectName(vkDevice_, VK_OBJECT_TYPE_SHADER_MODULE, (uint64_t)vkShaderModule, debugName));

LVK_ASSERT(vkShaderModule != VK_NULL_HANDLE);

return vkShaderModule;
SpvReflectShaderModule mdl;
SpvReflectResult result = spvReflectCreateShaderModule(numBytes, spirv, &mdl);
LVK_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
SCOPE_EXIT {
spvReflectDestroyShaderModule(&mdl);
};

uint32_t pushConstantsSize = 0;

for (uint32_t i = 0; i < mdl.push_constant_block_count; ++i) {
const SpvReflectBlockVariable* block = &mdl.push_constant_blocks[i];
pushConstantsSize = std::max(pushConstantsSize, block->offset + block->size);
}

return {
.sm = vkShaderModule,
.pushConstantsSize = pushConstantsSize,
};
}

VkShaderModule lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage,
lvk::ShaderModuleState lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage,
const char* source,
const char* debugName,
Result* outResult) const {
Expand All @@ -4041,7 +4062,7 @@ VkShaderModule lvk::VulkanContext::createShaderModuleFromGLSL(ShaderStage stage,

if (!source || !*source) {
Result::setResult(outResult, Result::Code::ArgumentOutOfRange, "Shader source is empty");
return VK_NULL_HANDLE;
return {};
}

if (strstr(source, "#version ") == nullptr) {
Expand Down Expand Up @@ -5286,7 +5307,13 @@ void lvk::VulkanContext::invokeShaderModuleErrorCallback(int line, int col, cons
return;
}

lvk::ShaderModuleHandle handle = shaderModulesPool_.findObject(&sm);
lvk::ShaderModuleHandle handle;

for (uint32_t i = 0; i != shaderModulesPool_.objects_.size(); i++) {
if (shaderModulesPool_.objects_[i].obj_.sm == sm) {
handle = shaderModulesPool_.getHandle(i);
}
}

if (!handle.empty()) {
config_.shaderModuleErrorCallback(this, handle, line, col, debugName);
Expand Down
13 changes: 9 additions & 4 deletions lvk/vulkan/VulkanClasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ struct ComputePipelineState final {
VkPipeline pipeline_ = VK_NULL_HANDLE;
};

struct ShaderModuleState final {
VkShaderModule sm = VK_NULL_HANDLE;
uint32_t pushConstantsSize = 0;
};

class CommandBuffer final : public ICommandBuffer {
public:
CommandBuffer() = default;
Expand Down Expand Up @@ -516,7 +521,7 @@ class VulkanContext final : public IContext {
ColorSpace getSwapChainColorSpace() const override;
uint32_t getNumSwapchainImages() const override;
void recreateSwapchain(int newWidth, int newHeight) override;

uint32_t getFramebufferMSAABitMask() const override;

double getTimestampPeriodToMs() const override;
Expand Down Expand Up @@ -592,8 +597,8 @@ class VulkanContext final : public IContext {
void processDeferredTasks() const;
void waitDeferredTasks();
lvk::Result growDescriptorPool(uint32_t maxTextures, uint32_t maxSamplers);
VkShaderModule createShaderModuleFromSPIRV(const void* spirv, size_t numBytes, const char* debugName, Result* outResult) const;
VkShaderModule createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const;
ShaderModuleState createShaderModuleFromSPIRV(const void* spirv, size_t numBytes, const char* debugName, Result* outResult) const;
ShaderModuleState createShaderModuleFromGLSL(ShaderStage stage, const char* source, const char* debugName, Result* outResult) const;

private:
friend class lvk::VulkanSwapchain;
Expand Down Expand Up @@ -651,7 +656,7 @@ class VulkanContext final : public IContext {

lvk::ContextConfig config_;

lvk::Pool<lvk::ShaderModule, VkShaderModule> shaderModulesPool_;
lvk::Pool<lvk::ShaderModule, lvk::ShaderModuleState> shaderModulesPool_;
lvk::Pool<lvk::RenderPipeline, lvk::RenderPipelineState> renderPipelinesPool_;
lvk::Pool<lvk::ComputePipeline, lvk::ComputePipelineState> computePipelinesPool_;
lvk::Pool<lvk::Sampler, VkSampler> samplersPool_;
Expand Down

0 comments on commit d6407d7

Please sign in to comment.