diff --git a/llpc/context/llpcCompiler.cpp b/llpc/context/llpcCompiler.cpp index 1b2bf2e385..396348367f 100644 --- a/llpc/context/llpcCompiler.cpp +++ b/llpc/context/llpcCompiler.cpp @@ -575,7 +575,11 @@ Result Compiler::BuildShaderModule(const ShaderModuleBuildInfo *shaderInfo, Shad return Result::ErrorInvalidPointer; } - unsigned codeSize = ShaderModuleHelper::getCodeSize(shaderInfo); + auto codeSizeOrErr = ShaderModuleHelper::getCodeSize(shaderInfo); + if (Error err = codeSizeOrErr.takeError()) { + return errorToResult(std::move(err)); + } + const unsigned codeSize = *codeSizeOrErr; size_t allocSize = sizeof(ShaderModuleData) + codeSize; ShaderModuleData moduleData = {}; diff --git a/llpc/util/llpcShaderModuleHelper.cpp b/llpc/util/llpcShaderModuleHelper.cpp index 56bd1f4232..76deca13d5 100644 --- a/llpc/util/llpcShaderModuleHelper.cpp +++ b/llpc/util/llpcShaderModuleHelper.cpp @@ -31,6 +31,7 @@ #include "llpcShaderModuleHelper.h" #include "llpcDebug.h" #include "llpcUtil.h" +#include "llpcError.h" #include "spirvExt.h" #include "vkgcUtil.h" #include "llvm/Support/CommandLine.h" @@ -233,8 +234,8 @@ ShaderModuleUsage ShaderModuleHelper::getShaderModuleUsageInfo(const BinaryData // // @param spvBin : SPIR-V binary code // @param codeBuffer : The buffer in which to copy the shader code. -// @returns : The number of bytes written to trimSpvBin -unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer) { +// @returns : The number of bytes written to trimSpvBin or an error if invalid data was encountered +Expected ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer) { bool writeCode = !codeBuffer.empty(); assert(codeBuffer.empty() || codeBuffer.size() > sizeof(SpirvHeader)); @@ -260,6 +261,12 @@ unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm:: while (codePos < end) { unsigned opCode = (codePos[0] & OpCodeMask); unsigned wordCount = (codePos[0] >> WordCountShift); + + if (wordCount == 0 || codePos + wordCount > end) { + LLPC_ERRS("Invalid SPIR-V binary\n"); + return createResultError(Result::ErrorInvalidShader); + } + bool skip = false; switch (opCode) { case OpSource: @@ -496,8 +503,12 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo if (moduleData.binType == BinaryType::Spirv) { moduleData.usage = ShaderModuleHelper::getShaderModuleUsageInfo(&shaderBinary); - moduleData.binCode = getShaderCode(shaderInfo, codeBuffer); moduleData.usage.isInternalRtShader = shaderInfo->options.pipelineOptions.internalRtShaders; + auto codeOrErr = getShaderCode(shaderInfo, codeBuffer); + if (Error err = codeOrErr.takeError()) { + return errorToResult(std::move(err)); + } + moduleData.binCode = *codeOrErr; // Calculate SPIR-V cache hash Hash cacheHash = {}; @@ -520,13 +531,16 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo // @param shaderInfo : Shader module build info // @param codeBuffer [out] : A buffer to hold the shader code. // @return : The BinaryData for the shaderCode written to codeBuffer. -BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo, +Expected ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo, MutableArrayRef &codeBuffer) { BinaryData code; const BinaryData &shaderBinary = shaderInfo->shaderBin; bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders); if (trimDebugInfo) { - code.codeSize = trimSpirvDebugInfo(&shaderBinary, codeBuffer); + auto sizeOrErr = trimSpirvDebugInfo(&shaderBinary, codeBuffer); + if (Error err = sizeOrErr.takeError()) + return err; + code.codeSize = *sizeOrErr; } else { assert(shaderBinary.codeSize <= codeBuffer.size() * sizeof(codeBuffer.front())); memcpy(codeBuffer.data(), shaderBinary.pCode, shaderBinary.codeSize); @@ -539,7 +553,7 @@ BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shader // ===================================================================================================================== // @param shaderInfo : Shader module build info // @return : The number of bytes need to hold the code for this shader module. -unsigned ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) { +Expected ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) { const BinaryData &shaderBinary = shaderInfo->shaderBin; bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders); if (!trimDebugInfo) diff --git a/llpc/util/llpcShaderModuleHelper.h b/llpc/util/llpcShaderModuleHelper.h index 9a07bf1c87..f7a7b22304 100644 --- a/llpc/util/llpcShaderModuleHelper.h +++ b/llpc/util/llpcShaderModuleHelper.h @@ -32,6 +32,7 @@ #pragma once #include "llpc.h" #include +#include #include namespace Llpc { @@ -56,7 +57,7 @@ class ShaderModuleHelper { public: static ShaderModuleUsage getShaderModuleUsageInfo(const BinaryData *spvBinCode); - static unsigned trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer); + static llvm::Expected trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef codeBuffer); static Result optimizeSpirv(const BinaryData *spirvBinIn, BinaryData *spirvBinOut); @@ -70,8 +71,8 @@ class ShaderModuleHelper { static Result getShaderBinaryType(BinaryData shaderBinary, BinaryType &binaryType); static Result getModuleData(const ShaderModuleBuildInfo *shaderInfo, llvm::MutableArrayRef codeBuffer, Vkgc::ShaderModuleData &moduleData); - static unsigned getCodeSize(const ShaderModuleBuildInfo *shaderInfo); - static BinaryData getShaderCode(const ShaderModuleBuildInfo *shaderInfo, + static llvm::Expected getCodeSize(const ShaderModuleBuildInfo *shaderInfo); + static llvm::Expected getShaderCode(const ShaderModuleBuildInfo *shaderInfo, llvm::MutableArrayRef &codeBuffer); };