Skip to content

Commit

Permalink
Iterate over the spir-v fields to handle bitfields (#6746)
Browse files Browse the repository at this point in the history
The code that implements `RWByteAddressBuffer::Store` will iterate over
all of the fields in a struct to write each element in the struct.
However, it does not use the "Spir-V fields", which account for
multiple fields being packed into the same bitfield. This is fixed by
using the `forEachSpirvField` function to make sure that bitfields
are correctly handled.

Fixes #6483
  • Loading branch information
s-perron authored Jul 19, 2024
1 parent 74ba845 commit c01b4f4
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 63 deletions.
134 changes: 71 additions & 63 deletions tools/clang/lib/SPIRV/RawBufferMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "RawBufferMethods.h"
#include "AlignmentSizeCalculator.h"
#include "LowerTypeVisitor.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/CharUnits.h"
#include "clang/AST/RecordLayout.h"
Expand Down Expand Up @@ -284,44 +285,48 @@ SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
// aligned like their field with the largest alignment.
// As a result, there might exist some padding after some struct members.
if (const auto *structType = targetType->getAs<RecordType>()) {
const auto *decl = structType->getDecl();
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
theEmitter.getSpirvOptions(), spvBuilder);
auto *decl = targetType->getAsTagDecl();
assert(decl && "Expected all structs to be tag decls.");
const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType(
targetType, theEmitter.getSpirvOptions().sBufferLayoutRule, llvm::None,
decl->getLocation()));
llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
uint32_t fieldOffsetInBytes = 0;
uint32_t structAlignment = 0, structSize = 0, stride = 0;
std::tie(structAlignment, structSize) =
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
.getAlignmentAndSize(targetType,
theEmitter.getSpirvOptions().sBufferLayoutRule,
llvm::None, &stride);
for (const auto *field : decl->fields()) {
AlignmentSizeCalculator alignmentCalc(astContext,
theEmitter.getSpirvOptions());
uint32_t fieldSize = 0, fieldAlignment = 0;
std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
/*isRowMajor*/ llvm::None, &stride);
fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
auto *byteOffset = address.getByteAddress();
if (fieldOffsetInBytes != 0) {
byteOffset = spvBuilder.createBinaryOp(
spv::Op::OpIAdd, astContext.UnsignedIntTy, byteOffset,
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, fieldOffsetInBytes)),
loc, range);
}

loadedElems.push_back(processTemplatedLoadFromBuffer(
buffer, byteOffset, field->getType(), range));
forEachSpirvField(
structType, spvType,
[this, &buffer, &address, range,
&loadedElems](size_t spirvFieldIndex, const QualType &fieldType,
const auto &field) {
auto *baseOffset = address.getByteAddress();
if (field.offset.hasValue() && field.offset.getValue() != 0) {
const auto loc = buffer->getSourceLocation();
SpirvConstant *offset = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, field.offset.getValue()));
baseOffset = spvBuilder.createBinaryOp(
spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, offset,
loc, range);
}

fieldOffsetInBytes += fieldSize;
}
loadedElems.push_back(processTemplatedLoadFromBuffer(
buffer, baseOffset, fieldType, range));
return true;
});

// After we're done with loading the entire struct, we need to update the
// byteAddress (in case we are loading an array of structs).
//
// struct size = 34 bytes
// (34 / 8) = 4 full words
// (34 % 8) = 2 > 0, therefore need to move to the next aligned address
// So the starting byte offset after loading the entire struct is:
// 8 * (4 + 1) = 40
uint32_t structAlignment = 0, structSize = 0, stride = 0;
std::tie(structAlignment, structSize) =
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
.getAlignmentAndSize(targetType,
theEmitter.getSpirvOptions().sBufferLayoutRule,
llvm::None, &stride);

assert(structAlignment != 0);
SpirvInstruction *structWidth = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
Expand Down Expand Up @@ -577,7 +582,7 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
return;
default:
theEmitter.emitError(
"templated load of ByteAddressBuffer is only implemented for "
"templated store of ByteAddressBuffer is only implemented for "
"16, 32, and 64-bit types",
loc);
return;
Expand All @@ -604,40 +609,36 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
// aligned like their field with the largest alignment.
// As a result, there might exist some padding after some struct members.
if (const auto *structType = valueType->getAs<RecordType>()) {
const auto *decl = structType->getDecl();
uint32_t fieldOffsetInBytes = 0;
uint32_t structAlignment = 0, structSize = 0, stride = 0;
std::tie(structAlignment, structSize) =
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
.getAlignmentAndSize(valueType,
theEmitter.getSpirvOptions().sBufferLayoutRule,
llvm::None, &stride);
uint32_t fieldIndex = 0;
for (const auto *field : decl->fields()) {
AlignmentSizeCalculator alignmentCalc(astContext,
theEmitter.getSpirvOptions());
uint32_t fieldSize = 0, fieldAlignment = 0;
std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
/*isRowMajor*/ llvm::None, &stride);
fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
auto *byteOffset = address.getByteAddress();
if (fieldOffsetInBytes != 0) {
byteOffset = spvBuilder.createBinaryOp(
spv::Op::OpIAdd, astContext.UnsignedIntTy, byteOffset,
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
llvm::APInt(32, fieldOffsetInBytes)),
loc, range);
}
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
theEmitter.getSpirvOptions(), spvBuilder);
auto *decl = valueType->getAsTagDecl();
assert(decl && "Expected all structs to be tag decls.");
const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType(
valueType, theEmitter.getSpirvOptions().sBufferLayoutRule, llvm::None,
decl->getLocation()));
assert(spvType);
forEachSpirvField(
structType, spvType,
[this, &address, loc, range, buffer, value](size_t spirvFieldIndex,
const QualType &fieldType,
const auto &field) {
auto *baseOffset = address.getByteAddress();
if (field.offset.hasValue() && field.offset.getValue() != 0) {
SpirvConstant *offset = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
llvm::APInt(32, field.offset.getValue()));
baseOffset = spvBuilder.createBinaryOp(
spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, offset,
loc, range);
}

processTemplatedStoreToBuffer(
spvBuilder.createCompositeExtract(field->getType(), value,
{fieldIndex}, loc, range),
buffer, byteOffset, field->getType(), range);

fieldOffsetInBytes += fieldSize;
++fieldIndex;
}
processTemplatedStoreToBuffer(
spvBuilder.createCompositeExtract(
fieldType, value, {static_cast<uint32_t>(spirvFieldIndex)},
loc, range),
buffer, baseOffset, fieldType, range);
return true;
});

// After we're done with storing the entire struct, we need to update the
// byteAddress (in case we are storing an array of structs).
Expand All @@ -647,6 +648,13 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
// (34 % 8) = 2 > 0, therefore need to move to the next aligned address
// So the starting byte offset after storing the entire struct is:
// 8 * (4 + 1) = 40
uint32_t structAlignment = 0, structSize = 0, stride = 0;
std::tie(structAlignment, structSize) =
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
.getAlignmentAndSize(valueType,
theEmitter.getSpirvOptions().sBufferLayoutRule,
llvm::None, &stride);

assert(structAlignment != 0);
auto *structWidth = spvBuilder.getConstantInt(
astContext.UnsignedIntTy,
Expand Down
12 changes: 12 additions & 0 deletions tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.load.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

ByteAddressBuffer myBuffer;

// Struct with two 8-bit bitfields. The FileCheck expectations for
// `myBuffer.Load<S>(0)` at the end of main() expect a single %uint load
// whose result is used to construct %S.
struct S {
uint32_t x : 8;
uint32_t y : 8;
};

[numthreads(1, 1, 1)]
void main() {
uint addr = 0;
Expand Down Expand Up @@ -50,4 +55,11 @@ void main() {
// CHECK-NEXT: [[load4_word3:%[0-9]+]] = OpLoad %uint [[load_ptr6]]
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %v4uint [[load4_word0]] [[load4_word1]] [[load4_word2]] [[load4_word3]]
uint4 word4 = myBuffer.Load4(addr);

// CHECK: [[idx:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[idx]]
// CHECK: [[bitfield:%[0-9]+]] = OpLoad %uint [[ac]]
// CHECK: [[s:%[0-9]+]] = OpCompositeConstruct %S [[bitfield]]
// CHECK: OpStore %s [[s]]
S s = myBuffer.Load<S>(0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

RWByteAddressBuffer outBuffer;

// Struct with two 8-bit bitfields. The FileCheck expectations for
// `outBuffer.Store(0, s)` at the end of main() expect a single %uint to be
// extracted from %S (member 0) and stored to the buffer.
struct S
{
uint32_t x:8;
uint32_t y:8;
};

[numthreads(1, 1, 1)]
void main() {
uint addr = 0;
Expand Down Expand Up @@ -67,4 +73,13 @@ void main() {
// CHECK-NEXT: [[outBufPtr3:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr_plus3]]
// CHECK-NEXT: OpStore [[outBufPtr3]] [[word3]]
outBuffer.Store4(addr, words4);

// CHECK: [[s:%[0-9]+]] = OpLoad %S %s
// CHECK: [[bitfield:%[0-9]+]] = OpCompositeExtract %uint [[s]] 0
// CHECK: [[idx:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[idx]]
// CHECK: OpStore [[ac]] [[bitfield]]
S s = (S)0;
s.x = 5;
outBuffer.Store(0, s);
}

0 comments on commit c01b4f4

Please sign in to comment.