Skip to content

Commit

Permalink
Changes based on the code review.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jan 26, 2024
1 parent 1d271cd commit d42b590
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 49 deletions.
108 changes: 62 additions & 46 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,12 +571,21 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
// field-index because bitfields are merged into a single field in the SPIR-V
// representation.
//
// If `includeMerged` is true, `operation` will be called on the same spir-v
// field for each field it represents. For example, if a spir-v field holds the
// values for 3 bit-fields, `operation` will be called 3 times with the same
// `spirvFieldIndex`. The `bitfield` information in `field` will be different.
//
// If false, `operation` will be called once on the first field in the merged
// field.
//
// If the operation returns false, we stop processing fields.
void forEachSpirvField(
const RecordType *recordType, const StructType *spirvType,
std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
const StructType::FieldInfo &field)>
operation) {
operation,
bool includeMerged) {
const auto *cxxDecl = recordType->getAsCXXRecordDecl();
const auto *recordDecl = recordType->getDecl();

Expand All @@ -598,7 +607,8 @@ void forEachSpirvField(
for (const auto *field : recordDecl->fields()) {
const auto &spirvField = spirvType->getFields()[astFieldIndex];
const uint32_t currentFieldIndex = spirvField.fieldIndex;
if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
if (!includeMerged && astFieldIndex > 0 &&
currentFieldIndex == lastConvertedIndex) {
++astFieldIndex;
continue;
}
Expand Down Expand Up @@ -3565,9 +3575,9 @@ SpirvInstruction *SpirvEmitter::processFlatConversion(
std::vector<SpirvInstruction *> flatValues = decomposeToScalars(initInstr);

if (flatValues.size() == 1) {
return SplatScalarToGenerate(type, flatValues[0], SpirvLayoutRule::Void);
return splatScalarToGenerate(type, flatValues[0], SpirvLayoutRule::Void);
}
return GenerateFromScalars(type, flatValues, SpirvLayoutRule::Void);
return generateFromScalars(type, flatValues, SpirvLayoutRule::Void);
}

SpirvInstruction *
Expand Down Expand Up @@ -6615,7 +6625,8 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,
reconstructValue(subSrcVal, fieldType, dstLR, loc, range));

return true;
});
},
false);

auto *result = spvBuilder.createCompositeConstruct(
valType, elements, srcVal->getSourceLocation(), range);
Expand Down Expand Up @@ -7152,29 +7163,31 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
uint32_t vectorIndex = 0;
uint32_t elemCount = 1;
llvm::SmallVector<SpirvInstruction *, 4> members;
forEachSpirvField(astStructType->getAs<RecordType>(), spirvStructType,
[&](size_t spirvFieldIndex, const QualType &fieldType,
const auto &field) {
if (isScalarType(fieldType)) {
members.push_back(spvBuilder.createCompositeExtract(
elemType, vector, {vectorIndex++}, loc, range));
return true;
}
forEachSpirvField(
astStructType->getAs<RecordType>(), spirvStructType,
[&](size_t spirvFieldIndex, const QualType &fieldType,
const auto &field) {
if (isScalarType(fieldType)) {
members.push_back(spvBuilder.createCompositeExtract(
elemType, vector, {vectorIndex++}, loc, range));
return true;
}

if (isVectorType(fieldType, nullptr, &elemCount)) {
llvm::SmallVector<uint32_t, 4> indices;
for (uint32_t i = 0; i < elemCount; ++i)
indices.push_back(vectorIndex++);
if (isVectorType(fieldType, nullptr, &elemCount)) {
llvm::SmallVector<uint32_t, 4> indices;
for (uint32_t i = 0; i < elemCount; ++i)
indices.push_back(vectorIndex++);

members.push_back(spvBuilder.createVectorShuffle(
astContext.getExtVectorType(elemType, elemCount),
vector, vector, indices, loc, range));
return true;
}
members.push_back(spvBuilder.createVectorShuffle(
astContext.getExtVectorType(elemType, elemCount), vector, vector,
indices, loc, range));
return true;
}

assert(false && "unhandled type");
return false;
});
assert(false && "unhandled type");
return false;
},
false);

return spvBuilder.createCompositeConstruct(
astStructType, members, vector->getSourceLocation(), range);
Expand Down Expand Up @@ -14544,10 +14557,11 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) {
resultType = hlsl::GetHLSLResourceResultType(resultType);
}

// switch(inst->getAstResultType()->getKind())
if (isScalarType(resultType)) {
return {inst};
} else if (isVectorType(resultType, &elementType, &elementCount)) {
}

if (isVectorType(resultType, &elementType, &elementCount)) {
std::vector<SpirvInstruction *> result;
for (uint32_t i = 0; i < elementCount; i++) {
auto *element = spvBuilder.createCompositeExtract(
Expand All @@ -14556,7 +14570,9 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) {
result.push_back(element);
}
return result;
} else if (isMxNMatrix(resultType, &elementType, &numOfRows, &numOfCols)) {
}

if (isMxNMatrix(resultType, &elementType, &numOfRows, &numOfCols)) {
std::vector<SpirvInstruction *> result;
for (uint32_t i = 0; i < numOfRows; i++) {
for (uint32_t j = 0; j < numOfCols; j++) {
Expand All @@ -14567,7 +14583,9 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) {
}
}
return result;
} else if (isArrayType(resultType, &elementType, &elementCount)) {
}

if (isArrayType(resultType, &elementType, &elementCount)) {
std::vector<SpirvInstruction *> result;
for (uint32_t i = 0; i < elementCount; i++) {
auto *element = spvBuilder.createCompositeExtract(
Expand Down Expand Up @@ -14604,19 +14622,17 @@ SpirvEmitter::decomposeToScalars(SpirvInstruction *inst) {
result.insert(result.end(), decomposedField.begin(),
decomposedField.end());
return true;
});
},
true);
return result;
}

else {
resultType->dump();
llvm_unreachable("Trying to decompose a type that we cannot decompose");
}
llvm_unreachable("Trying to decompose a type that we cannot decompose");
return {};
}

SpirvInstruction *
SpirvEmitter::GenerateFromScalars(QualType type,
SpirvEmitter::generateFromScalars(QualType type,
std::vector<SpirvInstruction *> &scalars,
SpirvLayoutRule layoutRule) {
QualType elementType;
Expand Down Expand Up @@ -14676,7 +14692,7 @@ SpirvEmitter::GenerateFromScalars(QualType type,
} else if (isArrayType(type, &elementType, &elementCount)) {
std::vector<SpirvInstruction *> elements;
for (uint32_t i = 0; i < elementCount; i++) {
elements.push_back(GenerateFromScalars(elementType, scalars, layoutRule));
elements.push_back(generateFromScalars(elementType, scalars, layoutRule));
}
SpirvInstruction *result = spvBuilder.createCompositeConstruct(
type, elements, scalars[0]->getSourceLocation());
Expand Down Expand Up @@ -14707,9 +14723,9 @@ SpirvEmitter::GenerateFromScalars(QualType type,
return {};
}

SpirvInstruction *
SpirvEmitter::SplatScalarToGenerate(QualType type, SpirvInstruction *scalar,
SpirvLayoutRule layoutRule) {
SpirvInstruction *SpirvEmitter::splatScalarToGenerate(QualType type,
SpirvInstruction *scalar,
SpirvLayoutRule rule) {
QualType elementType;
uint32_t elementCount = 0;
uint32_t numOfRows = 0;
Expand All @@ -14718,7 +14734,7 @@ SpirvEmitter::SplatScalarToGenerate(QualType type, SpirvInstruction *scalar,
if (isScalarType(type)) {
// If the type if bool with a non-void layout rule, then it should be
// treated as a uint.
assert(layoutRule == SpirvLayoutRule::Void &&
assert(rule == SpirvLayoutRule::Void &&
"If the layout type is not void, then we should cast to an int when "
"type is a boolean.");
QualType sourceType = scalar->getAstResultType();
Expand All @@ -14737,7 +14753,7 @@ SpirvEmitter::SplatScalarToGenerate(QualType type, SpirvInstruction *scalar,
std::vector<SpirvInstruction *> elements(elementCount, element);
SpirvInstruction *result = spvBuilder.createCompositeConstruct(
type, elements, scalar->getSourceLocation());
result->setLayoutRule(layoutRule);
result->setLayoutRule(rule);
return result;
} else if (isMxNMatrix(type, &elementType, &numOfRows, &numOfCols)) {
SourceLocation loc = scalar->getSourceLocation();
Expand All @@ -14751,19 +14767,19 @@ SpirvEmitter::SplatScalarToGenerate(QualType type, SpirvInstruction *scalar,
QualType rowType = astContext.getExtVectorType(elementType, numOfCols);
SpirvInstruction *r =
spvBuilder.createCompositeConstruct(rowType, row, loc);
r->setLayoutRule(layoutRule);
r->setLayoutRule(rule);
std::vector<SpirvInstruction *> rows(numOfRows, r);
SpirvInstruction *result =
spvBuilder.createCompositeConstruct(type, rows, loc);
result->setLayoutRule(layoutRule);
result->setLayoutRule(rule);
return result;
} else if (isArrayType(type, &elementType, &elementCount)) {
SpirvInstruction *element =
SplatScalarToGenerate(elementType, scalar, layoutRule);
splatScalarToGenerate(elementType, scalar, rule);
std::vector<SpirvInstruction *> elements(elementCount, element);
SpirvInstruction *result = spvBuilder.createCompositeConstruct(
type, elements, scalar->getSourceLocation());
result->setLayoutRule(layoutRule);
result->setLayoutRule(rule);
return result;
} else if (const RecordType *recordType = dyn_cast<RecordType>(type)) {
SourceLocation loc = scalar->getSourceLocation();
Expand All @@ -14782,7 +14798,7 @@ SpirvEmitter::SplatScalarToGenerate(QualType type, SpirvInstruction *scalar,
});
SpirvInstruction *result =
spvBuilder.createCompositeConstruct(type, elements, loc);
result->setLayoutRule(layoutRule);
result->setLayoutRule(rule);
return result;
} else {
llvm_unreachable("Trying to generate a type that we cannot generate");
Expand Down
6 changes: 3 additions & 3 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1257,14 +1257,14 @@ class SpirvEmitter : public ASTConsumer {
/// rule that is obtained by assigning each scalar in `type` to corresponding
/// value in `scalars`. This is the inverse of `decomposeToScalars`.
SpirvInstruction *
GenerateFromScalars(QualType type, std::vector<SpirvInstruction *> &scalars,
generateFromScalars(QualType type, std::vector<SpirvInstruction *> &scalars,
SpirvLayoutRule layoutRule);

/// Returns a spirv instruction with the value of the given type and layout
/// rule that is obtained by assigning `scalar` each scalar in `type`. This is
/// the same as calling `GenerateFromScalars` with a sufficiently large vector
/// the same as calling `generateFromScalars` with a sufficiently large vector
/// where every element is `scalar`.
SpirvInstruction *SplatScalarToGenerate(QualType type,
SpirvInstruction *splatScalarToGenerate(QualType type,
SpirvInstruction *scalar,
SpirvLayoutRule rule);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ struct T {
int64_t j;
};

struct UT {
uint32_t i;
uint64_t j;
};

void main() {

// CHECK: [[inf:%[0-9]+]] = OpFDiv %float %float_1 %float_0
Expand Down Expand Up @@ -59,4 +64,20 @@ void main() {
// CHECK: [[t:%[0-9]+]] = OpCompositeConstruct %T [[lit]] [[longLit]]
// CHECK: OpStore %t [[t]]
T t = (T)(0x100000000+1);

// TODO(6188): This is wrong because we lose most significant bits in the literal.
// CHECK: [[lit:%[0-9]+]] = OpIAdd %uint %uint_0 %uint_1
// CHECK: [[longLit:%[0-9]+]] = OpUConvert %ulong [[lit]]
// CHECK: [[t:%[0-9]+]] = OpCompositeConstruct %UT [[lit]] [[longLit]]
// CHECK: OpStore %ut [[t]]
UT ut = (UT)(0x100000000ul+1);

// TODO(6188): This is wrong because we lose most significant bits in the literal.
// CHECK: [[longLit:%[0-9]+]] = OpIAdd %ulong %ulong_4294967296 %ulong_1
// CHECK: [[lit:%[0-9]+]] = OpUConvert %uint [[longLit]]
// CHECK: [[lit2:%[0-9]+]] = OpBitcast %int [[lit]]
// CHECK: [[longLit2:%[0-9]+]] = OpBitcast %long [[longLit]]
// CHECK: [[t:%[0-9]+]] = OpCompositeConstruct %T [[lit2]] [[longLit2]]
// CHECK: OpStore %t2 [[t]]
T t2 = (T)(0x100000000ull+1);
}
18 changes: 18 additions & 0 deletions tools/clang/test/CodeGenSPIRV/cast.struct-to-int.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct Vectors {

RWStructuredBuffer<uint> buf : r0;
RWStructuredBuffer<uint64_t> lbuf : r1;
RWStructuredBuffer<Vectors> vbuf : r2;

// CHECK: OpName [[BUF:%[^ ]*]] "buf"
// CHECK: OpName [[LBUF:%[^ ]*]] "lbuf"
Expand Down Expand Up @@ -111,5 +112,22 @@ void main()
// CHECK: [[V3:%[^ ]*]] = OpLoad [[ULONG]] [[LBUF00_0]]
// CHECK: [[V4:%[^ ]*]] = OpIAdd [[ULONG]] [[V3]] [[V2_0]]
// CHECK: OpStore [[LBUF00_0]] [[V4]]

vbuf[0] = (Vectors) colors;
// CHECK: [[c0:%[^ ]*]] = OpLoad {{%[^ ]*}} %colors
// CHECK: [[c0_0:%[^ ]+]] = OpCompositeExtract %ColorRGBA [[c0]] 0
// The entire bit container extracted for each bitfield.
// CHECK: [[c0_0_0:%[^ ]*]] = OpCompositeExtract %uint [[c0_0]] 0
// CHECK: [[c0_0_1:%[^ ]*]] = OpCompositeExtract %uint [[c0_0]] 0
// CHECK: [[c0_0_2:%[^ ]*]] = OpCompositeExtract %uint [[c0_0]] 0
// CHECK: [[c0_0_3:%[^ ]*]] = OpCompositeExtract %uint [[c0_0]] 0
// CHECK: [[v0:%[^ ]*]] = OpCompositeConstruct %v2uint [[c0_0_0]] [[c0_0_1]]
// CHECK: [[v1:%[^ ]*]] = OpCompositeConstruct %v2uint [[c0_0_2]] [[c0_0_3]]
// CHECK: [[v:%[^ ]*]] = OpCompositeConstruct %Vectors_0 [[v0]] [[v1]]
// CHECK: [[vbuf:%[^ ]*]] = OpAccessChain %{{[^ ]*}} %vbuf [[I0]] [[U0]]
// CHECK: [[v0:%[^ ]*]] = OpCompositeExtract %v2uint [[v]] 0
// CHECK: [[v1:%[^ ]*]] = OpCompositeExtract %v2uint [[v]] 1
// CHECK: [[v:%[^ ]*]] = OpCompositeConstruct %Vectors [[v0]] [[v1]]
// CHECK: OpStore [[vbuf]] [[v]]
}

0 comments on commit d42b590

Please sign in to comment.