Skip to content

Commit

Permalink
[Diag] Diagnose missing numthreads in compute shaders
Browse files Browse the repository at this point in the history
Currently we error out in validation if a compute shader doesn't have
a numthreads attribute:

```
<source>:2: error: Declared Thread Group X size 2833376888 outside valid range [1..1024].
<source>:2: error: Declared Thread Group Y size 21956 outside valid range [1..1024].
<source>:2: error: Declared Thread Group Z size 2833381336 outside valid range [1..64].
<source>:2: error: Declared Thread Group Count 3863291136 (X*Y*Z) is beyond the valid maximum of 1024.
```

Instead, diagnose this in Sema like we do for mesh and amplitude shaders

Reviewers: damyanp, bob80905, pow2clk, llvm-beanz

Reviewed By: llvm-beanz, bob80905, damyanp, pow2clk

Pull Request: microsoft#6021
  • Loading branch information
bogner authored Nov 16, 2023
1 parent e3c3114 commit 7c33007
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 15 deletions.
17 changes: 5 additions & 12 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12584,19 +12584,12 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
}

void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
// If not explicitly specified, x, y, and z should be defaulted to 1.
uint32_t x = 1, y = 1, z = 1;
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
assert(numThreadsAttr && "thread group size missing from entry-point");

if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
x = static_cast<uint32_t>(numThreadsAttr->getX());
y = static_cast<uint32_t>(numThreadsAttr->getY());
z = static_cast<uint32_t>(numThreadsAttr->getZ());
} else {
emitError("thread group size [numthreads(x,y,z)] is missing from the "
"entry-point function",
decl->getLocation());
return;
}
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());

spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15425,6 +15425,9 @@ void DiagnoseGeometryEntry(Sema &S, FunctionDecl *FD,
void DiagnoseComputeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
bool isActiveEntry) {
if (isActiveEntry) {
if (!(FD->getAttr<HLSLNumThreadsAttr>()))
S.Diags.Report(FD->getLocation(), diag::err_hlsl_missing_attr)
<< StageName << "numthreads";
if (auto WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
std::string profile = S.getLangOpts().HLSLProfile;
const ShaderModel *SM = hlsl::ShaderModel::GetByName(profile.c_str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ void entryHistogram(uint3 id: SV_DispatchThreadID, uint idx: SV_GroupIndex)
{
}

// CHECK: 11:6: error: thread group size [numthreads(x,y,z)] is missing from the entry-point function
// CHECK: 11:6: error: compute entry point must have the numthreads attribute
[shader("compute")]
void entryAverage(uint3 id: SV_DispatchThreadID, uint idx: SV_GroupIndex)
{
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: not %dxc -T cs_6_0 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

// CHECK: 4:6: error: thread group size [numthreads(x,y,z)] is missing from the entry-point function
// CHECK: 4:6: error: compute entry point must have the numthreads attribute
void main() {}
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ void MSmain(out vertices myvert verts[32],
myvert v = {0.0, 0.0, 0.0, 0.0};
verts[ix] = v;
}

// expected-error@+2{{compute entry point must have the numthreads attribute}}
[shader("compute")]
void CSmain() {}

0 comments on commit 7c33007

Please sign in to comment.