From 64a36e889fadb9fae80490c1336a69f9660b6f1c Mon Sep 17 00:00:00 2001 From: Pietro Saccardi Date: Fri, 1 Dec 2023 12:14:27 +0100 Subject: [PATCH 1/2] Make BuildComputePipeline fail if the shader does not pass validation Validation might fail e.g. because the shader entry point is missing. This change allows the error to bubble up through the call stack until vkCreateComputePipelines, instead of killing the app later on with exit code 1. BuildComputePipeline is converted to early-return style. --- llpc/context/llpcCompiler.cpp | 50 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/llpc/context/llpcCompiler.cpp b/llpc/context/llpcCompiler.cpp index 659bfc2f95..f8b0d04bb4 100644 --- a/llpc/context/llpcCompiler.cpp +++ b/llpc/context/llpcCompiler.cpp @@ -2289,13 +2289,15 @@ Result Compiler::BuildComputePipeline(const ComputePipelineBuildInfo *pipelineIn const bool buildUsingRelocatableElf = relocatableElfRequested && canUseRelocatableComputeShaderElf(pipelineInfo); Result result = validatePipelineShaderInfo(&pipelineInfo->cs); + if (result != Result::Success) + return result; MetroHash::Hash cacheHash = {}; MetroHash::Hash pipelineHash = {}; cacheHash = PipelineDumper::generateHashForComputePipeline(pipelineInfo, true); pipelineHash = PipelineDumper::generateHashForComputePipeline(pipelineInfo, false); - if (result == Result::Success && EnableOuts()) { + if (EnableOuts()) { const ShaderModuleData *moduleData = reinterpret_cast(pipelineInfo->cs.pModuleData); auto moduleHash = reinterpret_cast(&moduleData->hash[0]); LLPC_OUTS("\n===============================================================================\n"); @@ -2310,8 +2312,7 @@ Result Compiler::BuildComputePipeline(const ComputePipelineBuildInfo *pipelineIn LLPC_OUTS("\n"); } - if (result == Result::Success) - dumpCompilerOptions(pipelineDumpFile); + dumpCompilerOptions(pipelineDumpFile); std::optional cacheAccessor; if (cl::CacheFullPipelines) { @@ -2325,40 +2326,39 @@ Result Compiler::BuildComputePipeline(const ComputePipelineBuildInfo *pipelineIn result = buildComputePipelineInternal(&computeContext, pipelineInfo, buildUsingRelocatableElf, &candidateElf, &pipelineOut->stageCacheAccess); - if (result == Result::Success) { - elfBin.codeSize = candidateElf.size(); - elfBin.pCode = candidateElf.data(); - } if (cacheAccessor && pipelineOut->pipelineCacheAccess == CacheAccessInfo::CacheNotChecked) pipelineOut->pipelineCacheAccess = CacheAccessInfo::CacheMiss; + + if (result != Result::Success) { + return result; + } + elfBin.codeSize = candidateElf.size(); + elfBin.pCode = candidateElf.data(); } else { LLPC_OUTS("Cache hit for compute pipeline.\n"); elfBin = cacheAccessor->getElfFromCache(); pipelineOut->pipelineCacheAccess = CacheAccessInfo::InternalCacheHit; } - if (result == Result::Success) { - void *allocBuf = nullptr; - if (pipelineInfo->pfnOutputAlloc) { - allocBuf = pipelineInfo->pfnOutputAlloc(pipelineInfo->pInstance, pipelineInfo->pUserData, elfBin.codeSize); - if (allocBuf) { - uint8_t *code = static_cast(allocBuf); - memcpy(code, elfBin.pCode, elfBin.codeSize); + if (!pipelineInfo->pfnOutputAlloc) // Allocator is not specified + return Result::ErrorInvalidPointer; - pipelineOut->pipelineBin.codeSize = elfBin.codeSize; - pipelineOut->pipelineBin.pCode = code; - } else - result = Result::ErrorOutOfMemory; - } else { - // Allocator is not specified - result = Result::ErrorInvalidPointer; - } - } + void *const allocBuf = + pipelineInfo->pfnOutputAlloc(pipelineInfo->pInstance, pipelineInfo->pUserData, elfBin.codeSize); + if (!allocBuf) + return Result::ErrorOutOfMemory; - if (cacheAccessor && !cacheAccessor->isInCache() && result == Result::Success) { + uint8_t *code = static_cast(allocBuf); + memcpy(code, elfBin.pCode, elfBin.codeSize); + + pipelineOut->pipelineBin.codeSize = elfBin.codeSize; + pipelineOut->pipelineBin.pCode = code; + + if (cacheAccessor && !cacheAccessor->isInCache()) { cacheAccessor->setElfInCache(elfBin); } - return result; + + return Result::Success; } // ===================================================================================================================== From c7656ef95a4e8fee461160b4f2bcc4069e2141a7 Mon Sep 17 00:00:00 2001 From: Pietro Saccardi Date: Fri, 1 Dec 2023 12:42:23 +0100 Subject: [PATCH 2/2] Prevent trimSpirvDebugInfo from looping indefinitely on invalid input An invalid SPIR-V binary might contain data that does not match any opCode, and yields a wordCount of zero, triggering an infinite loop. trimSpirvDebugInfo now returns Expected and will give an error in such a situation. Caller functions getShaderCode and getCodeSize (as well as their callers BuildShaderModule and getModuleData) have been updated to handle the error. --- llpc/context/llpcCompiler.cpp | 6 +++++- llpc/util/llpcShaderModuleHelper.cpp | 29 +++++++++++++++++++++------- llpc/util/llpcShaderModuleHelper.h | 10 ++++++---- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/llpc/context/llpcCompiler.cpp b/llpc/context/llpcCompiler.cpp index f8b0d04bb4..d21c447006 100644 --- a/llpc/context/llpcCompiler.cpp +++ b/llpc/context/llpcCompiler.cpp @@ -575,7 +575,11 @@ Result Compiler::BuildShaderModule(const ShaderModuleBuildInfo *shaderInfo, Shad return Result::ErrorInvalidPointer; } - unsigned codeSize = ShaderModuleHelper::getCodeSize(shaderInfo); + auto codeSizeOrErr = ShaderModuleHelper::getCodeSize(shaderInfo); + if (Error err = codeSizeOrErr.takeError()) + return errorToResult(std::move(err)); + + const unsigned codeSize = *codeSizeOrErr; size_t allocSize = sizeof(ShaderModuleData) + codeSize; ShaderModuleData moduleData = {}; diff --git a/llpc/util/llpcShaderModuleHelper.cpp b/llpc/util/llpcShaderModuleHelper.cpp index 56bd1f4232..862d0f1fc7 100644 --- a/llpc/util/llpcShaderModuleHelper.cpp +++ b/llpc/util/llpcShaderModuleHelper.cpp @@ -30,6 +30,7 @@ */ #include "llpcShaderModuleHelper.h" #include "llpcDebug.h" +#include "llpcError.h" #include "llpcUtil.h" #include "spirvExt.h" #include "vkgcUtil.h" @@ -233,8 +234,9 @@ ShaderModuleUsage ShaderModuleHelper::getShaderModuleUsageInfo(const BinaryData // // @param spvBin : SPIR-V binary code // @param codeBuffer : The buffer in which to copy the shader code. -// @returns : The number of bytes written to trimSpvBin -unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer) { +// @returns : The number of bytes written to trimSpvBin or an error if invalid data was encountered +Expected ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, + llvm::MutableArrayRef codeBuffer) { bool writeCode = !codeBuffer.empty(); assert(codeBuffer.empty() || codeBuffer.size() > sizeof(SpirvHeader)); @@ -260,6 +262,12 @@ unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm:: while (codePos < end) { unsigned opCode = (codePos[0] & OpCodeMask); unsigned wordCount = (codePos[0] >> WordCountShift); + + if (wordCount == 0 || codePos + wordCount > end) { + LLPC_ERRS("Invalid SPIR-V binary\n"); + return createResultError(Result::ErrorInvalidShader); + } + bool skip = false; switch (opCode) { case OpSource: @@ -496,8 +504,12 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo if (moduleData.binType == BinaryType::Spirv) { moduleData.usage = ShaderModuleHelper::getShaderModuleUsageInfo(&shaderBinary); - moduleData.binCode = getShaderCode(shaderInfo, codeBuffer); moduleData.usage.isInternalRtShader = shaderInfo->options.pipelineOptions.internalRtShaders; + auto codeOrErr = getShaderCode(shaderInfo, codeBuffer); + if (Error err = codeOrErr.takeError()) + return errorToResult(std::move(err)); + + moduleData.binCode = *codeOrErr; // Calculate SPIR-V cache hash Hash cacheHash = {}; @@ -520,13 +532,16 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo // @param shaderInfo : Shader module build info // @param codeBuffer [out] : A buffer to hold the shader code. // @return : The BinaryData for the shaderCode written to codeBuffer. -BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo, - MutableArrayRef &codeBuffer) { +Expected ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo, + MutableArrayRef &codeBuffer) { BinaryData code; const BinaryData &shaderBinary = shaderInfo->shaderBin; bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders); if (trimDebugInfo) { - code.codeSize = trimSpirvDebugInfo(&shaderBinary, codeBuffer); + auto sizeOrErr = trimSpirvDebugInfo(&shaderBinary, codeBuffer); + if (Error err = sizeOrErr.takeError()) + return err; + code.codeSize = *sizeOrErr; } else { assert(shaderBinary.codeSize <= codeBuffer.size() * sizeof(codeBuffer.front())); memcpy(codeBuffer.data(), shaderBinary.pCode, shaderBinary.codeSize); @@ -539,7 +554,7 @@ BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shader // ===================================================================================================================== // @param shaderInfo : Shader module build info // @return : The number of bytes need to hold the code for this shader module. -unsigned ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) { +Expected ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) { const BinaryData &shaderBinary = shaderInfo->shaderBin; bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders); if (!trimDebugInfo) diff --git a/llpc/util/llpcShaderModuleHelper.h b/llpc/util/llpcShaderModuleHelper.h index 9a07bf1c87..07a0a9424d 100644 --- a/llpc/util/llpcShaderModuleHelper.h +++ b/llpc/util/llpcShaderModuleHelper.h @@ -32,6 +32,7 @@ #pragma once #include "llpc.h" #include +#include #include namespace Llpc { @@ -56,7 +57,8 @@ class ShaderModuleHelper { public: static ShaderModuleUsage getShaderModuleUsageInfo(const BinaryData *spvBinCode); - static unsigned trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer); + static llvm::Expected trimSpirvDebugInfo(const BinaryData *spvBin, + llvm::MutableArrayRef codeBuffer); static Result optimizeSpirv(const BinaryData *spirvBinIn, BinaryData *spirvBinOut); @@ -70,9 +72,9 @@ class ShaderModuleHelper { static Result getShaderBinaryType(BinaryData shaderBinary, BinaryType &binaryType); static Result getModuleData(const ShaderModuleBuildInfo *shaderInfo, llvm::MutableArrayRef codeBuffer, Vkgc::ShaderModuleData &moduleData); - static unsigned getCodeSize(const ShaderModuleBuildInfo *shaderInfo); - static BinaryData getShaderCode(const ShaderModuleBuildInfo *shaderInfo, - llvm::MutableArrayRef &codeBuffer); + static llvm::Expected getCodeSize(const ShaderModuleBuildInfo *shaderInfo); + static llvm::Expected getShaderCode(const ShaderModuleBuildInfo *shaderInfo, + llvm::MutableArrayRef &codeBuffer); }; } // namespace Llpc