diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 7576d8d870..12a495d97c 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -2129,6 +2129,12 @@ bool DeclResultIdMapper::createStageVars( noWriteBack, /*vecComponent=*/nullptr, loc)) return true; + auto evalElemType = + hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type; + const auto vecSizeOfType = + hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1; + bool differentTypeIsUsedForEvalType = false; + switch (semanticKind) { case hlsl::Semantic::Kind::DomainLocation: evalType = astContext.getExtVectorType(astContext.FloatTy, 3); @@ -2157,12 +2163,19 @@ bool DeclResultIdMapper::createStageVars( evalType = astContext.getExtVectorType(astContext.FloatTy, 2); break; case hlsl::Semantic::Kind::DispatchThreadID: + // Based on SPIR-V spec, we have to always use a vector with 3 int + // elements for DispatchThreadID. Therefore, we use `astContext.IntTy` + // instead of `type`. + if (!evalElemType->isIntegerType()) { + evalElemType = astContext.IntTy; + differentTypeIsUsedForEvalType = true; + } + evalType = astContext.getExtVectorType(evalElemType, 3); + break; case hlsl::Semantic::Kind::GroupThreadID: case hlsl::Semantic::Kind::GroupID: // Keep the original integer signedness - evalType = astContext.getExtVectorType( - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type, - 3); + evalType = astContext.getExtVectorType(evalElemType, 3); break; case hlsl::Semantic::Kind::ShadingRate: evalType = astContext.getExtVectorType(astContext.IntTy, 2); @@ -2392,17 +2405,12 @@ bool DeclResultIdMapper::createStageVars( semanticKind == hlsl::Semantic::Kind::GroupID) && (!hlsl::IsHLSLVecType(type) || hlsl::GetHLSLVecSize(type) != 3)) { - const auto srcVecElemType = hlsl::IsHLSLVecType(type) - ? hlsl::GetHLSLVecElementType(type) - : type; - const auto vecSize = - hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1; - if (vecSize == 1) - *value = spvBuilder.createCompositeExtract(srcVecElemType, *value, - {0}, thisSemantic.loc); - else if (vecSize == 2) + if (vecSizeOfType == 1) + *value = spvBuilder.createCompositeExtract(evalElemType, *value, {0}, + thisSemantic.loc); + else if (vecSizeOfType == 2) *value = spvBuilder.createVectorShuffle( - astContext.getExtVectorType(srcVecElemType, 2), *value, *value, + astContext.getExtVectorType(evalElemType, 2), *value, *value, {0, 1}, thisSemantic.loc); } // Special handling of SV_ShadingRate, which is a bitpacked enum value, @@ -2521,6 +2529,22 @@ bool DeclResultIdMapper::createStageVars( } } + // When it is a special stage variable that the given QualType is not + // allowed based on SPIR-V spec, we use a evalType different from the given + // AST type for stageVar. To follow the given HLSL context, We have to + // conduct the proper type-cast. + if (differentTypeIsUsedForEvalType) { + auto *stageVarWithGivenAstType = + spvBuilder.addFnVar(type, thisSemantic.loc, name); + auto *castInst = + theEmitter.castToType(*value, evalType, type, thisSemantic.loc); + spvBuilder.createStore(stageVarWithGivenAstType, castInst, + thisSemantic.loc); + stageVarInstructions[cast(decl)] = + stageVarWithGivenAstType; + *value = spvBuilder.createLoad(type, stageVarWithGivenAstType, loc); + } + return true; } diff --git a/tools/clang/test/CodeGenSPIRV/spirv.interface.cs.float.dispatch-thread-id.hlsl b/tools/clang/test/CodeGenSPIRV/spirv.interface.cs.float.dispatch-thread-id.hlsl new file mode 100644 index 0000000000..c5888f4c2c --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spirv.interface.cs.float.dispatch-thread-id.hlsl @@ -0,0 +1,19 @@ +// Run: %dxc -T cs_6_0 -E main + +// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3int Input + +// CHECK: %in_var_SV_DispatchThreadID = OpVariable %_ptr_Function_v2float Function + +// CHECK: [[load:%\d+]] = OpLoad %v3int %gl_GlobalInvocationID +// CHECK-NEXT: [[shuffle:%\d+]] = OpVectorShuffle %v2int [[load]] [[load]] 0 1 +// CHECK-NEXT: [[cast:%\d+]] = OpConvertSToF %v2float [[shuffle]] +// CHECK-NEXT: OpStore %in_var_SV_DispatchThreadID [[cast]] +// CHECK-NEXT: OpLoad %v2float %in_var_SV_DispatchThreadID + +RWStructuredBuffer rwTexture; + +[numthreads(1, 1, 1)] +void main(float2 id : SV_DispatchThreadID) +{ + rwTexture[3] = id.xxxx; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index 37d92aa22f..21667aa7b7 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -1561,6 +1561,10 @@ TEST_F(FileTest, SpirvStageIO16bitTypes) { runFileTest("spirv.stage-io.16bit.hlsl"); } +TEST_F(FileTest, DispatchThreadIdWithFloatType) { + runFileTest("spirv.interface.cs.float.dispatch-thread-id.hlsl"); +} + TEST_F(FileTest, SpirvInterpolationPS) { runFileTest("spirv.interpolation.ps.hlsl"); }