Skip to content

Commit

Permalink
Prevent trimSpirvDebugInfo from looping indefinitely on invalid input
Browse files Browse the repository at this point in the history
An invalid SPIR-V binary might contain data that does not match any opCode, and yields a wordCount of zero, triggering an infinite loop.

trimSpirvDebugInfo now returns Expected<unsigned> and will give an error in such a situation.
Caller functions getShaderCode and getCodeSize (as well as their callers BuildShaderModule and getModuleData) have been updated to handle the error.
  • Loading branch information
5p4k committed Dec 12, 2023
1 parent f8feeaa commit 8ae925f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
6 changes: 5 additions & 1 deletion llpc/context/llpcCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,11 @@ Result Compiler::BuildShaderModule(const ShaderModuleBuildInfo *shaderInfo, Shad
return Result::ErrorInvalidPointer;
}

unsigned codeSize = ShaderModuleHelper::getCodeSize(shaderInfo);
auto codeSizeOrErr = ShaderModuleHelper::getCodeSize(shaderInfo);
if (Error err = codeSizeOrErr.takeError()) {
return errorToResult(std::move(err));
}
const unsigned codeSize = *codeSizeOrErr;
size_t allocSize = sizeof(ShaderModuleData) + codeSize;

ShaderModuleData moduleData = {};
Expand Down
26 changes: 20 additions & 6 deletions llpc/util/llpcShaderModuleHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "llpcShaderModuleHelper.h"
#include "llpcDebug.h"
#include "llpcUtil.h"
#include "llpcError.h"
#include "spirvExt.h"
#include "vkgcUtil.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -233,8 +234,8 @@ ShaderModuleUsage ShaderModuleHelper::getShaderModuleUsageInfo(const BinaryData
//
// @param spvBin : SPIR-V binary code
// @param codeBuffer : The buffer in which to copy the shader code.
// @returns : The number of bytes written to trimSpvBin
unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef<unsigned> codeBuffer) {
// @returns : The number of bytes written to trimSpvBin or an error if invalid data was encountered
Expected<unsigned> ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef<unsigned> codeBuffer) {
bool writeCode = !codeBuffer.empty();
assert(codeBuffer.empty() || codeBuffer.size() > sizeof(SpirvHeader));

Expand All @@ -260,6 +261,12 @@ unsigned ShaderModuleHelper::trimSpirvDebugInfo(const BinaryData *spvBin, llvm::
while (codePos < end) {
unsigned opCode = (codePos[0] & OpCodeMask);
unsigned wordCount = (codePos[0] >> WordCountShift);

if (wordCount == 0 || codePos + wordCount > end) {
LLPC_ERRS("Invalid SPIR-V binary\n");
return createResultError(Result::ErrorInvalidShader);
}

bool skip = false;
switch (opCode) {
case OpSource:
Expand Down Expand Up @@ -496,8 +503,12 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo

if (moduleData.binType == BinaryType::Spirv) {
moduleData.usage = ShaderModuleHelper::getShaderModuleUsageInfo(&shaderBinary);
moduleData.binCode = getShaderCode(shaderInfo, codeBuffer);
moduleData.usage.isInternalRtShader = shaderInfo->options.pipelineOptions.internalRtShaders;
auto codeOrErr = getShaderCode(shaderInfo, codeBuffer);
if (Error err = codeOrErr.takeError()) {
return errorToResult(std::move(err));
}
moduleData.binCode = *codeOrErr;

// Calculate SPIR-V cache hash
Hash cacheHash = {};
Expand All @@ -520,13 +531,16 @@ Result ShaderModuleHelper::getModuleData(const ShaderModuleBuildInfo *shaderInfo
// @param shaderInfo : Shader module build info
// @param codeBuffer [out] : A buffer to hold the shader code.
// @return : The BinaryData for the shaderCode written to codeBuffer.
BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo,
Expected<BinaryData> ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shaderInfo,
MutableArrayRef<unsigned int> &codeBuffer) {
BinaryData code;
const BinaryData &shaderBinary = shaderInfo->shaderBin;
bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders);
if (trimDebugInfo) {
code.codeSize = trimSpirvDebugInfo(&shaderBinary, codeBuffer);
auto sizeOrErr = trimSpirvDebugInfo(&shaderBinary, codeBuffer);
if (Error err = sizeOrErr.takeError())
return err;
code.codeSize = *sizeOrErr;
} else {
assert(shaderBinary.codeSize <= codeBuffer.size() * sizeof(codeBuffer.front()));
memcpy(codeBuffer.data(), shaderBinary.pCode, shaderBinary.codeSize);
Expand All @@ -539,7 +553,7 @@ BinaryData ShaderModuleHelper::getShaderCode(const ShaderModuleBuildInfo *shader
// =====================================================================================================================
// @param shaderInfo : Shader module build info
// @return : The number of bytes need to hold the code for this shader module.
unsigned ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) {
Expected<unsigned> ShaderModuleHelper::getCodeSize(const ShaderModuleBuildInfo *shaderInfo) {
const BinaryData &shaderBinary = shaderInfo->shaderBin;
bool trimDebugInfo = cl::TrimDebugInfo && !(shaderInfo->options.pipelineOptions.internalRtShaders);
if (!trimDebugInfo)
Expand Down
7 changes: 4 additions & 3 deletions llpc/util/llpcShaderModuleHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#pragma once
#include "llpc.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Error.h>
#include <vector>

namespace Llpc {
Expand All @@ -56,7 +57,7 @@ class ShaderModuleHelper {
public:
static ShaderModuleUsage getShaderModuleUsageInfo(const BinaryData *spvBinCode);

static unsigned trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef<unsigned> codeBuffer);
static llvm::Expected<unsigned> trimSpirvDebugInfo(const BinaryData *spvBin, llvm::MutableArrayRef<unsigned> codeBuffer);

static Result optimizeSpirv(const BinaryData *spirvBinIn, BinaryData *spirvBinOut);

Expand All @@ -70,8 +71,8 @@ class ShaderModuleHelper {
static Result getShaderBinaryType(BinaryData shaderBinary, BinaryType &binaryType);
static Result getModuleData(const ShaderModuleBuildInfo *shaderInfo, llvm::MutableArrayRef<unsigned> codeBuffer,
Vkgc::ShaderModuleData &moduleData);
static unsigned getCodeSize(const ShaderModuleBuildInfo *shaderInfo);
static BinaryData getShaderCode(const ShaderModuleBuildInfo *shaderInfo,
static llvm::Expected<unsigned> getCodeSize(const ShaderModuleBuildInfo *shaderInfo);
static llvm::Expected<BinaryData> getShaderCode(const ShaderModuleBuildInfo *shaderInfo,
llvm::MutableArrayRef<unsigned int> &codeBuffer);
};

Expand Down

0 comments on commit 8ae925f

Please sign in to comment.