Skip to content

Commit

Permalink
Fix bad WaveSize non-range PSV0 and improve things
Browse files Browse the repository at this point in the history
PSV0 MaximumExpectedWaveLaneCount was incorrectly set to 0 for non-range.

Create struct for WaveSize in DxilFunctionProps.h.
Centralize encoding and validation logic there.
Use validation logic in both SemaHLSL and DxilValidation.

Remove test requiring newer shader model in CodeGenHLSL.
Add comprehensive test for compute and node, cs and lib targets, SM 6.6 and 6.8, testing ast, metadata and RDAT.
Add PSV0 tests to catch incorrect runtime data.

Update validation rules and test for more cases.
  • Loading branch information
tex3d committed Jan 22, 2024
1 parent 00e170f commit b7172ff
Show file tree
Hide file tree
Showing 27 changed files with 719 additions and 218 deletions.
9 changes: 6 additions & 3 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3261,11 +3261,14 @@ SM.TRIOUTPUTPRIMITIVEMISMATCH Hull Shader declared with Tri Domain m
SM.UNDEFINEDOUTPUT Not all elements of output %0 were written.
SM.VALIDDOMAIN Invalid Tessellator Domain specified. Must be isoline, tri or quad.
SM.VIEWIDNEEDSSLOT ViewID requires compatible space in pixel shader input signature
SM.WAVESIZEMINGEQMAX Declared Minimum WaveSize %0 greater or equal to declared Maximum Wavesize %1
SM.WAVESIZEALLZEROWHENUNDEFINED WaveSize Max and Preferred must be 0 when Min is 0
SM.WAVESIZEMAXANDPREFERREDZEROWHENNORANGE WaveSize Max and Preferred must be 0 to encode min==max
SM.WAVESIZEMAXGREATERTHANMIN WaveSize Max must greater than Min
SM.WAVESIZENEEDSDXIL16PLUS WaveSize is valid only for DXIL version 1.6 and higher.
SM.WAVESIZEPREFERREDOUTOFRANGE Preferred WaveSize %0 outside valid range [%1..%2]
SM.WAVESIZEONCOMPUTEORNODE WaveSize only allowed on compute or node shaders
SM.WAVESIZEPREFERREDINRANGE WaveSize Preferred must be within Min..Max range
SM.WAVESIZERANGENEEDSDXIL18PLUS WaveSize Range is valid only for DXIL version 1.8 and higher.
SM.WAVESIZEVALUE Declared WaveSize %0 outside valid range [%1..%2], or not a power of 2.
SM.WAVESIZEVALUE WaveSize value must be a power of 2 in range [4..128]
SM.ZEROHSINPUTCONTROLPOINTWITHINPUT When HS input control point count is 0, no input signature should exist.
TYPES.DEFINED Type must be defined based on DXIL primitives
TYPES.I8 I8 can only be used as immediate value for intrinsic or as i8* via bitcast by lifetime intrinsics.
Expand Down
20 changes: 0 additions & 20 deletions include/dxc/DXIL/DxilConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,26 +448,6 @@ inline bool IsFeedbackTexture(DXIL::ResourceKind ResourceKind) {
ResourceKind == DXIL::ResourceKind::FeedbackTexture2DArray;
}

inline bool isPowerOf2(unsigned x) { return (x & (x - 1)) == 0; }

inline bool IsValidWaveSizeValue(unsigned min_wave, unsigned max_wave,
unsigned pref_wave) {
// must be power of 2 between 4 and 128
bool minIsValid = min_wave >= kMinWaveSize && min_wave <= kMaxWaveSize &&
isPowerOf2(min_wave);
if (max_wave == 0)
return minIsValid;

bool maxIsValid = max_wave >= kMinWaveSize && max_wave <= kMaxWaveSize &&
isPowerOf2(max_wave);
// 0 is a valid value for the preferred wave size
bool prefIsValid =
pref_wave == 0 || (pref_wave >= kMinWaveSize &&
pref_wave <= kMaxWaveSize && isPowerOf2(pref_wave));

return minIsValid && maxIsValid && prefIsValid;
}

// TODO: change opcodes.
/* <py::lines('OPCODE-ENUM')>hctdb_instrhelp.get_enum_decl("OpCode")</py>*/
// OPCODE-ENUM:BEGIN
Expand Down
96 changes: 88 additions & 8 deletions include/dxc/DXIL/DxilFunctionProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,93 @@ class Constant;
} // namespace llvm

namespace hlsl {

// SM 6.6 allows WaveSize specification for only a single required size.
// SM 6.8+ allows specification of WaveSize as a min, max and preferred value.
struct DxilWaveSize {
unsigned Min = 0;
unsigned Max = 0;
unsigned Preferred = 0;

DxilWaveSize() = default;
DxilWaveSize(unsigned min, unsigned max = 0, unsigned preferred = 0)
: Min(min), Max(max), Preferred(preferred) {}
DxilWaveSize(const DxilWaveSize &other) = default;
DxilWaveSize &operator=(const DxilWaveSize &other) = default;
bool operator==(const DxilWaveSize &other) const {
return Min == other.Min && Max == other.Max && Preferred == other.Preferred;
}

// Encode valid DxilWaveSize, handling potential degenerate cases.
static DxilWaveSize Encode(unsigned min, unsigned max = 0,
unsigned preferred = 0) {
if (max == min)
max = 0;
if (max == 0 && preferred == min)
preferred = 0;
return DxilWaveSize(min, max, preferred);
}

// Valid non-zero values are powers of 2 between 4 and 128, inclusive.
static bool IsValidValue(unsigned Value) {
return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0));
}
// Valid encodings:
// 0, 0, 0: Not defined
// Min, 0, 0: single WaveSize (SM 6.6/6.7)
// Min, Max (> Min), 0 or Preferred (>= Min and <= Max): Range (SM 6.8+)
enum class ValidationResult {
Success,
InvalidMin,
InvalidMax,
InvalidPreferred,
MaxOrPreferredWhenUndefined,
PreferredWhenNoRange,
MaxEqualsMin,
MaxLessThanMin,
PreferredOutOfRange,
};
ValidationResult Validate() const {
if (Min == 0) { // Not defined
if (Max != 0 || Preferred != 0)
return ValidationResult::MaxOrPreferredWhenUndefined;
} else if (!IsValidValue(Min)) {
return ValidationResult::InvalidMin;
} else if (Max == 0) { // single WaveSize (SM 6.6/6.7)
if (Preferred != 0)
return ValidationResult::PreferredWhenNoRange;
} else if (!IsValidValue(Max)) {
return ValidationResult::InvalidMax;
} else if (Min == Max) {
return ValidationResult::MaxEqualsMin;
} else if (Max < Min) {
return ValidationResult::MaxLessThanMin;
} else if (Preferred != 0) {
if (!IsValidValue(Preferred))
return ValidationResult::InvalidPreferred;
if (Preferred < Min || Preferred > Max)
return ValidationResult::PreferredOutOfRange;
}
return ValidationResult::Success;
}
bool IsValid() const { return Validate() == ValidationResult::Success; }

bool IsDefined() const { return Min != 0; }
bool IsRange() const { return Max != 0; }
bool HasPreferred() const { return Preferred != 0; }

// Decode for range used for runtime data.
// Writes results and returns true if a valid size or range was defined.
bool DecodeMinMax(unsigned &min, unsigned &max) const {
if (!IsValid() || !IsDefined()) {
return false;
}
min = Min;
max = IsRange() ? Max : Min;
return true;
}
};

struct DxilFunctionProps {
DxilFunctionProps() {
memset(&ShaderProps, 0, sizeof(ShaderProps));
Expand All @@ -34,9 +121,6 @@ struct DxilFunctionProps {
memset(&Node, 0, sizeof(Node));
Node.LaunchType = DXIL::NodeLaunchType::Invalid;
Node.LocalRootArgumentsTableIndex = -1;
waveMinSize = 0;
waveMaxSize = 0;
wavePreferredSize = 0;
}
union {
// Geometry shader.
Expand Down Expand Up @@ -108,12 +192,8 @@ struct DxilFunctionProps {
NodeID NodeShaderSharedInput;
std::vector<NodeIOProperties> InputNodes;
std::vector<NodeIOProperties> OutputNodes;
DxilWaveSize WaveSize;

// SM 6.6 allows WaveSize specification for only a single required size. SM
// 6.8+ allows specification of WaveSize as a min, max and preferred value.
unsigned waveMinSize;
unsigned waveMaxSize;
unsigned wavePreferredSize;
// Save root signature for lib profile entry.
std::vector<uint8_t> serializedRootSignature;
void SetSerializedRootSignature(const uint8_t *pData, unsigned size) {
Expand Down
5 changes: 2 additions & 3 deletions include/dxc/DXIL/DxilModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,8 @@ class DxilModule {
unsigned GetNumThreads(unsigned idx) const;

// Compute shader
void SetWaveSize(unsigned size);
unsigned GetMinWaveSize() const;
unsigned GetMaxWaveSize() const;
DxilWaveSize &GetWaveSize();
const DxilWaveSize &GetWaveSize() const;

// Geometry shader.
DXIL::InputPrimitive GetInputPrimitive() const;
Expand Down
40 changes: 19 additions & 21 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1596,19 +1596,18 @@ MDTuple *DxilMDHelper::EmitDxilEntryProperties(uint64_t rawShaderFlag,
NumThreadVals.emplace_back(Uint32ToConstMD(props.numThreads[2]));
MDVals.emplace_back(MDNode::get(m_Ctx, NumThreadVals));

if (props.waveMinSize != 0) {
bool UseRange = props.waveMaxSize != 0;
if (UseRange)
if (props.WaveSize.IsDefined()) {
if (props.WaveSize.IsRange())
DXASSERT(DXIL::CompareVersions(m_MinValMajor, m_MinValMinor, 1, 8) >= 0,
"DXIL version must be > 1.8");
MDVals.emplace_back(
Uint32ToConstMD(UseRange ? DxilMDHelper::kDxilRangedWaveSizeTag
MDVals.emplace_back(Uint32ToConstMD(
props.WaveSize.IsRange() ? DxilMDHelper::kDxilRangedWaveSizeTag
: DxilMDHelper::kDxilWaveSizeTag));
SmallVector<Metadata *, 3> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMinSize));
if (UseRange) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMaxSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.wavePreferredSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Min));
if (props.WaveSize.IsRange()) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Max));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Preferred));
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}
Expand Down Expand Up @@ -1832,17 +1831,17 @@ void DxilMDHelper::LoadDxilEntryProperties(const MDOperand &MDO,
} break;
case DxilMDHelper::kDxilWaveSizeTag: {
MDNode *pNode = cast<MDNode>(MDO.get());
props.waveMinSize = ConstMDToUint32(pNode->getOperand(0));
props.WaveSize.Min = ConstMDToUint32(pNode->getOperand(0));
} break;
case DxilMDHelper::kDxilRangedWaveSizeTag: {
// if we're here, we're using the range variant.
// Extra metadata is used if SM < 6.8
if (!m_pSM->IsSMAtLeast(6, 8))
m_bExtraMetadata = true;
MDNode *pNode = cast<MDNode>(MDO.get());
props.waveMinSize = ConstMDToUint32(pNode->getOperand(0));
props.waveMaxSize = ConstMDToUint32(pNode->getOperand(1));
props.wavePreferredSize = ConstMDToUint32(pNode->getOperand(2));
props.WaveSize.Min = ConstMDToUint32(pNode->getOperand(0));
props.WaveSize.Max = ConstMDToUint32(pNode->getOperand(1));
props.WaveSize.Preferred = ConstMDToUint32(pNode->getOperand(2));
} break;
case DxilMDHelper::kDxilEntryRootSigTag: {
MDNode *pNode = cast<MDNode>(MDO.get());
Expand Down Expand Up @@ -2663,16 +2662,15 @@ void DxilMDHelper::EmitDxilNodeState(std::vector<llvm::Metadata *> &MDVals,

// Optional Fields

if (props.waveMinSize != 0) {
bool UseRange = props.waveMaxSize != 0;
MDVals.emplace_back(
Uint32ToConstMD(UseRange ? DxilMDHelper::kDxilRangedWaveSizeTag
if (props.WaveSize.IsDefined()) {
MDVals.emplace_back(Uint32ToConstMD(
props.WaveSize.IsRange() ? DxilMDHelper::kDxilRangedWaveSizeTag
: DxilMDHelper::kDxilWaveSizeTag));
SmallVector<Metadata *, 3> WaveSizeVal;
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMinSize));
if (UseRange) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.waveMaxSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.wavePreferredSize));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Min));
if (props.WaveSize.IsRange()) {
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Max));
WaveSizeVal.emplace_back(Uint32ToConstMD(props.WaveSize.Preferred));
}
MDVals.emplace_back(MDNode::get(m_Ctx, WaveSizeVal));
}
Expand Down
28 changes: 6 additions & 22 deletions lib/DXIL/DxilModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,32 +395,16 @@ unsigned DxilModule::GetNumThreads(unsigned idx) const {
return props.numThreads[idx];
}

void DxilModule::SetWaveSize(unsigned size) {
DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsCS(),
"only works for CS profile");
DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
props.waveMinSize = size;
DxilWaveSize &DxilModule::GetWaveSize() {
return const_cast<DxilWaveSize &>(
static_cast<const DxilModule *>(this)->GetWaveSize());
}

unsigned DxilModule::GetMinWaveSize() const {
const DxilWaveSize &DxilModule::GetWaveSize() const {
DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsCS(),
"only works for CS profiles");
if (!m_pSM->IsCS())
return 0;
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
return props.waveMinSize;
}

unsigned DxilModule::GetMaxWaveSize() const {
DXASSERT(m_DxilEntryPropsMap.size() == 1 && m_pSM->IsCS(),
"only works for CS profiles");
if (!m_pSM->IsCS())
return 0;
"only works for CS profile");
const DxilFunctionProps &props = m_DxilEntryPropsMap.begin()->second->props;
DXASSERT_NOMSG(m_pSM->GetKind() == props.shaderKind);
return props.waveMaxSize;
return props.WaveSize;
}

DXIL::InputPrimitive DxilModule::GetInputPrimitive() const {
Expand Down
37 changes: 28 additions & 9 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,11 +941,20 @@ class DxilPSVWriter : public DxilPartWriter {
break;
}
case ShaderModel::Kind::Compute: {
UINT waveMinSize = (UINT)m_Module.GetMinWaveSize();
UINT waveMaxSize = (UINT)m_Module.GetMaxWaveSize();
if (waveMinSize != 0) {
pInfo->MinimumExpectedWaveLaneCount = waveMinSize;
pInfo->MaximumExpectedWaveLaneCount = waveMaxSize;
const DxilWaveSize &waveSize = m_Module.GetWaveSize();
// Can't assert valid or validation tests will assert during assembly
// DXASSERT(waveSize.IsValid(), "wave size should be valid");
DXASSERT(!waveSize.IsDefined() || m_PSVInitInfo.PSVVersion >= 2,
"wave size requires SM 6.6 or above");
DXASSERT(!waveSize.IsRange() || m_PSVInitInfo.PSVVersion >= 3,
"wave size range requires SM 6.8 or above");
if (waveSize.IsDefined() && m_PSVInitInfo.PSVVersion >= 2) {
unsigned waveSizeMin = 0, waveSizeMax = 0;
waveSize.DecodeMinMax(waveSizeMin, waveSizeMax);
if (m_PSVInitInfo.PSVVersion < 3)
waveSizeMax = waveSizeMin;
pInfo->MinimumExpectedWaveLaneCount = waveSizeMin;
pInfo->MaximumExpectedWaveLaneCount = waveSizeMax;
}
break;
}
Expand Down Expand Up @@ -1800,6 +1809,7 @@ class DxilRDATWriter : public DxilPartWriter {
? &info_latest
: nullptr;
ShaderFlags flags = ShaderFlags::CollectShaderFlags(&function, &DM);
DxilWaveSize waveSize;
if (DM.HasDxilFunctionProps(&function)) {
const auto &props = DM.GetDxilFunctionProps(&function);
if (props.IsClosestHit() || props.IsAnyHit()) {
Expand All @@ -1811,12 +1821,21 @@ class DxilRDATWriter : public DxilPartWriter {
payloadSizeInBytes = props.ShaderProps.Ray.paramSizeInBytes;
}
shaderKind = (uint32_t)props.shaderKind;
waveSize = props.WaveSize;
// Can't assert valid or validation tests will assert during assembly
// DXASSERT(waveSize.IsValid(), "wave size should be valid");
DXASSERT(!waveSize.IsDefined() ||
DXIL::CompareVersions(m_ValMajor, m_ValMinor, 1, 6) >= 0,
"wave size requires SM 6.6 or above");
DXASSERT(!waveSize.IsRange() ||
DXIL::CompareVersions(m_ValMajor, m_ValMinor, 1, 8) >= 0,
"wave size range requires SM 6.8 or above");
if (pInfo2 && DM.HasDxilEntryProps(&function)) {
const auto &entryProps = DM.GetDxilEntryProps(&function);
pInfo2->MinimumExpectedWaveLaneCount = entryProps.props.waveMinSize;
pInfo2->MaximumExpectedWaveLaneCount =
entryProps.props.waveMaxSize > 0 ? entryProps.props.waveMaxSize
: entryProps.props.waveMinSize;
unsigned waveSizeMin = 0, waveSizeMax = 0;
waveSize.DecodeMinMax(waveSizeMin, waveSizeMax);
pInfo2->MinimumExpectedWaveLaneCount = waveSizeMin;
pInfo2->MaximumExpectedWaveLaneCount = waveSizeMax;
pInfo2->ShaderFlags = 0;
if (entryProps.props.IsNode()) {
shaderInfo = AddShaderNodeInfo(DM, function, entryProps, *pInfo2,
Expand Down
Loading

0 comments on commit b7172ff

Please sign in to comment.