Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] support non-int type SV_DispatchThreadId #3481

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2129,6 +2129,12 @@ bool DeclResultIdMapper::createStageVars(
noWriteBack, /*vecComponent=*/nullptr, loc))
return true;

auto evalElemType =
hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type;
const auto vecSizeOfType =
hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1;
bool differentTypeIsUsedForEvalType = false;

switch (semanticKind) {
case hlsl::Semantic::Kind::DomainLocation:
evalType = astContext.getExtVectorType(astContext.FloatTy, 3);
Expand Down Expand Up @@ -2157,12 +2163,19 @@ bool DeclResultIdMapper::createStageVars(
evalType = astContext.getExtVectorType(astContext.FloatTy, 2);
break;
case hlsl::Semantic::Kind::DispatchThreadID:
// Based on SPIR-V spec, we have to always use a vector with 3 int
// elements for DispatchThreadID. Therefore, we use `astContext.IntTy`
// instead of `type`.
if (!evalElemType->isIntegerType()) {
evalElemType = astContext.IntTy;
differentTypeIsUsedForEvalType = true;
}
evalType = astContext.getExtVectorType(evalElemType, 3);
break;
case hlsl::Semantic::Kind::GroupThreadID:
case hlsl::Semantic::Kind::GroupID:
// Keep the original integer signedness
evalType = astContext.getExtVectorType(
hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type,
3);
evalType = astContext.getExtVectorType(evalElemType, 3);
break;
case hlsl::Semantic::Kind::ShadingRate:
evalType = astContext.getExtVectorType(astContext.IntTy, 2);
Expand Down Expand Up @@ -2392,17 +2405,12 @@ bool DeclResultIdMapper::createStageVars(
semanticKind == hlsl::Semantic::Kind::GroupID) &&
(!hlsl::IsHLSLVecType(type) ||
hlsl::GetHLSLVecSize(type) != 3)) {
const auto srcVecElemType = hlsl::IsHLSLVecType(type)
? hlsl::GetHLSLVecElementType(type)
: type;
const auto vecSize =
hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1;
if (vecSize == 1)
*value = spvBuilder.createCompositeExtract(srcVecElemType, *value,
{0}, thisSemantic.loc);
else if (vecSize == 2)
if (vecSizeOfType == 1)
*value = spvBuilder.createCompositeExtract(evalElemType, *value, {0},
thisSemantic.loc);
else if (vecSizeOfType == 2)
*value = spvBuilder.createVectorShuffle(
astContext.getExtVectorType(srcVecElemType, 2), *value, *value,
astContext.getExtVectorType(evalElemType, 2), *value, *value,
{0, 1}, thisSemantic.loc);
}
// Special handling of SV_ShadingRate, which is a bitpacked enum value,
Expand Down Expand Up @@ -2521,6 +2529,22 @@ bool DeclResultIdMapper::createStageVars(
}
}

// When it is a special stage variable that the given QualType is not
// allowed based on SPIR-V spec, we use a evalType different from the given
// AST type for stageVar. To follow the given HLSL context, We have to
// conduct the proper type-cast.
if (differentTypeIsUsedForEvalType) {
auto *stageVarWithGivenAstType =
spvBuilder.addFnVar(type, thisSemantic.loc, name);
auto *castInst =
theEmitter.castToType(*value, evalType, type, thisSemantic.loc);
spvBuilder.createStore(stageVarWithGivenAstType, castInst,
thisSemantic.loc);
stageVarInstructions[cast<DeclaratorDecl>(decl)] =
stageVarWithGivenAstType;
*value = spvBuilder.createLoad(type, stageVarWithGivenAstType, loc);
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Run: %dxc -T cs_6_0 -E main

// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3int Input

// CHECK: %in_var_SV_DispatchThreadID = OpVariable %_ptr_Function_v2float Function

// CHECK: [[load:%\d+]] = OpLoad %v3int %gl_GlobalInvocationID
// CHECK-NEXT: [[shuffle:%\d+]] = OpVectorShuffle %v2int [[load]] [[load]] 0 1
// CHECK-NEXT: [[cast:%\d+]] = OpConvertSToF %v2float [[shuffle]]
// CHECK-NEXT: OpStore %in_var_SV_DispatchThreadID [[cast]]
// CHECK-NEXT: OpLoad %v2float %in_var_SV_DispatchThreadID

RWStructuredBuffer<float4> rwTexture;

[numthreads(1, 1, 1)]
void main(float2 id : SV_DispatchThreadID)
{
rwTexture[3] = id.xxxx;
}
4 changes: 4 additions & 0 deletions tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,10 @@ TEST_F(FileTest, SpirvStageIO16bitTypes) {
runFileTest("spirv.stage-io.16bit.hlsl");
}

TEST_F(FileTest, DispatchThreadIdWithFloatType) {
runFileTest("spirv.interface.cs.float.dispatch-thread-id.hlsl");
}

TEST_F(FileTest, SpirvInterpolationPS) {
runFileTest("spirv.interpolation.ps.hlsl");
}
Expand Down