Skip to content

Commit

Permalink
[SPIRV] Generate DebugTypeMatrix (#6757)
Browse files Browse the repository at this point in the history
When the OpenCL.DebugInfo.100 debug info was implemented, there was no
DebugTypeMatrix. Now that NonSemantic.Shader.DebugInfo.100 has been
merged, we should use DebugTypeMatrix. This PR corrects that oversight.
  • Loading branch information
SteveUrquhart authored Jul 24, 2024
1 parent 4a5253d commit e0fbce7
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 8 deletions.
4 changes: 4 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ class SpirvContext {
SpirvDebugInstruction *elemType,
uint32_t elemCount);

SpirvDebugType *getDebugTypeMatrix(const SpirvType *spirvType,
SpirvDebugInstruction *vectorType,
uint32_t vectorCount);

SpirvDebugType *getDebugTypeFunction(const SpirvType *spirvType,
uint32_t flags, SpirvDebugType *ret,
llvm::ArrayRef<SpirvDebugType *> params);
Expand Down
26 changes: 26 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class SpirvInstruction {
IK_DebugTypeBasic,
IK_DebugTypeArray,
IK_DebugTypeVector,
IK_DebugTypeMatrix,
IK_DebugTypeFunction,
IK_DebugTypeComposite,
IK_DebugTypeMember,
Expand Down Expand Up @@ -2879,6 +2880,31 @@ class SpirvDebugTypeVector : public SpirvDebugType {
uint32_t elementCount;
};

/// Represents matrix debug types
class SpirvDebugTypeMatrix : public SpirvDebugType {
public:
SpirvDebugTypeMatrix(SpirvDebugTypeVector *vectorType, uint32_t vectorCount);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvDebugTypeMatrix)

static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_DebugTypeMatrix;
}

bool invokeVisitor(Visitor *v) override;

SpirvDebugTypeVector *getVectorType() const { return vectorType; }
uint32_t getVectorCount() const { return vectorCount; }

uint32_t getSizeInBits() const override {
return vectorCount * vectorType->getSizeInBits();
}

private:
SpirvDebugTypeVector *vectorType;
uint32_t vectorCount;
};

/// Represents a function debug type. Includes the function return type and
/// parameter types.
class SpirvDebugTypeFunction : public SpirvDebugType {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/SPIRV/SpirvVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class Visitor {
DEFINE_VISIT_METHOD(SpirvDebugTypeBasic)
DEFINE_VISIT_METHOD(SpirvDebugTypeArray)
DEFINE_VISIT_METHOD(SpirvDebugTypeVector)
DEFINE_VISIT_METHOD(SpirvDebugTypeMatrix)
DEFINE_VISIT_METHOD(SpirvDebugTypeFunction)
DEFINE_VISIT_METHOD(SpirvDebugTypeComposite)
DEFINE_VISIT_METHOD(SpirvDebugTypeMember)
Expand Down
19 changes: 11 additions & 8 deletions tools/clang/lib/SPIRV/DebugTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,15 +365,18 @@ SpirvDebugType *DebugTypeVisitor::lowerToDebugType(const SpirvType *spirvType) {
break;
}
case SpirvType::TK_Matrix: {
// TODO: I temporarily use a DebugTypeArray for a matrix type.
// However, when the debug info extension supports matrix type
// e.g., DebugTypeMatrix, we must replace DebugTypeArray with
// DebugTypeMatrix.
auto *matType = dyn_cast<MatrixType>(spirvType);
SpirvDebugInstruction *elemDebugType =
lowerToDebugType(matType->getElementType());
debugType = spvContext.getDebugTypeArray(
spirvType, elemDebugType, {matType->numRows(), matType->numCols()});
if (spvOptions.debugInfoVulkan) {
SpirvDebugInstruction *vecDebugType =
lowerToDebugType(matType->getVecType());
debugType = spvContext.getDebugTypeMatrix(spirvType, vecDebugType,
matType->numCols());
} else {
SpirvDebugInstruction *elemDebugType =
lowerToDebugType(matType->getElementType());
debugType = spvContext.getDebugTypeArray(
spirvType, elemDebugType, {matType->numRows(), matType->numCols()});
}
break;
}
case SpirvType::TK_Pointer: {
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1662,6 +1662,21 @@ bool EmitVisitor::visit(SpirvDebugTypeVector *inst) {
return true;
}

bool EmitVisitor::visit(SpirvDebugTypeMatrix *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
curInst.push_back(
getOrAssignResultId<SpirvInstruction>(inst->getInstructionSet()));
curInst.push_back(inst->getDebugOpcode());
curInst.push_back(
getOrAssignResultId<SpirvInstruction>(inst->getVectorType()));
curInst.push_back(getLiteralEncodedForDebugInfo(inst->getVectorCount()));
curInst.push_back(getLiteralEncodedForDebugInfo(1));
finalizeInstruction(&richDebugInfo);
return true;
}

bool EmitVisitor::visit(SpirvDebugTypeArray *inst) {
initInstruction(inst);
curInst.push_back(inst->getResultTypeId());
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ class EmitVisitor : public Visitor {
bool visit(SpirvDebugExpression *) override;
bool visit(SpirvDebugTypeBasic *) override;
bool visit(SpirvDebugTypeVector *) override;
bool visit(SpirvDebugTypeMatrix *) override;
bool visit(SpirvDebugTypeArray *) override;
bool visit(SpirvDebugTypeFunction *) override;
bool visit(SpirvDebugTypeComposite *) override;
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/SortDebugInfoVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ void SortDebugInfoVisitor::whileEachOperandOfDebugInstruction(
if (!visitor(inst->getElementType()))
break;
} break;
case SpirvInstruction::IK_DebugTypeMatrix: {
SpirvDebugTypeMatrix *inst = cast<SpirvDebugTypeMatrix>(di);
assert(inst != nullptr);
visitor(inst->getVectorType());
} break;
case SpirvInstruction::IK_DebugTypeFunction: {
SpirvDebugTypeFunction *inst = dyn_cast<SpirvDebugTypeFunction>(di);
assert(inst != nullptr);
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,21 @@ SpirvContext::getDebugTypeVector(const SpirvType *spirvType,
return debugType;
}

SpirvDebugType *
SpirvContext::getDebugTypeMatrix(const SpirvType *spirvType,
SpirvDebugInstruction *vectorType,
uint32_t vectorCount) {
// Reuse existing debug type if possible.
if (debugTypes.find(spirvType) != debugTypes.end())
return debugTypes[spirvType];

auto *eTy = dyn_cast<SpirvDebugTypeVector>(vectorType);
assert(eTy && "Element type must be a SpirvDebugTypeVector.");
auto *debugType = new (this) SpirvDebugTypeMatrix(eTy, vectorCount);
debugTypes[spirvType] = debugType;
return debugType;
}

SpirvDebugType *
SpirvContext::getDebugTypeFunction(const SpirvType *spirvType, uint32_t flags,
SpirvDebugType *ret,
Expand Down
6 changes: 6 additions & 0 deletions tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugScope)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeBasic)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeArray)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeVector)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeMatrix)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeFunction)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeComposite)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeMember)
Expand Down Expand Up @@ -1102,6 +1103,11 @@ SpirvDebugTypeVector::SpirvDebugTypeVector(SpirvDebugType *elemType,
: SpirvDebugType(IK_DebugTypeVector, /*opcode*/ 6u), elementType(elemType),
elementCount(elemCount) {}

SpirvDebugTypeMatrix::SpirvDebugTypeMatrix(SpirvDebugTypeVector *vectorType,
uint32_t vectorCount)
: SpirvDebugType(IK_DebugTypeMatrix, /*opcode*/ 108u),
vectorType(vectorType), vectorCount(vectorCount) {}

SpirvDebugTypeFunction::SpirvDebugTypeFunction(
uint32_t flags, SpirvDebugType *ret,
llvm::ArrayRef<SpirvDebugType *> params)
Expand Down
24 changes: 24 additions & 0 deletions tools/clang/test/CodeGenSPIRV/shader.debug.type.matrix.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: %dxc -T ps_6_0 -E main -fspv-debug=vulkan -fcgl %s -spirv | FileCheck %s
// RUN: %dxc -T ps_6_2 -E main -fspv-debug=vulkan -fcgl %s -spirv -enable-16bit-types | FileCheck %s --check-prefix=CHECK-HALF

// CHECK: [[float:%[0-9]+]] = OpExtInst %void {{%[0-9]+}} DebugTypeBasic {{%[0-9]+}} %uint_32 %uint_3 %uint_0
// CHECK: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeVector {{%[0-9]+}} %uint_4
// CHECK: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeMatrix {{%[0-9]+}} %uint_3 %uint_1
// CHECK: [[float:%[0-9]+]] = OpExtInst %void {{%[0-9]+}} DebugTypeBasic {{%[0-9]+}} %uint_64 %uint_3 %uint_0
// CHECK: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeVector {{%[0-9]+}} %uint_4
// CHECK: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeMatrix {{%[0-9]+}} %uint_3 %uint_1
// CHECK-HALF: [[float:%[0-9]+]] = OpExtInst %void {{%[0-9]+}} DebugTypeBasic {{%[0-9]+}} %uint_16 %uint_3 %uint_0
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeVector {{%[0-9]+}} %uint_4
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeMatrix {{%[0-9]+}} %uint_3 %uint_1
// CHECK-HALF: [[float:%[0-9]+]] = OpExtInst %void {{%[0-9]+}} DebugTypeBasic {{%[0-9]+}} %uint_64 %uint_3 %uint_0
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeVector {{%[0-9]+}} %uint_4
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeMatrix {{%[0-9]+}} %uint_3 %uint_1
// CHECK-HALF: [[float:%[0-9]+]] = OpExtInst %void {{%[0-9]+}} DebugTypeBasic {{%[0-9]+}} %uint_32 %uint_3 %uint_0
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeVector {{%[0-9]+}} %uint_4
// CHECK-HALF: {{%[0-9]+}} = OpExtInst %void {{%[0-9]+}} DebugTypeMatrix {{%[0-9]+}} %uint_3 %uint_1

void main() {
float3x4 mat_float;
double3x4 mat_double;
half3x4 mat_half;
}

0 comments on commit e0fbce7

Please sign in to comment.