Skip to content

Commit

Permalink
[SPIR-V] Avoid emitting Int64 when loading Float64 (#7073)
Browse files Browse the repository at this point in the history
When loading a Float64 from a raw buffer, we used an Int64, which
required an additional capability, even if the code wasn't using any
Int64.
In practice, it seems most devices supporting Float64 do also support
Int64, but this it doesn't have to.
By changing the codegen a bit, we can avoid the Int64 value.

Tested the word-order using a vulkan compute shader, and checking the
returned value on the API side.
```hlsl
double tmp = buffer.Load<double>(0);
if (tmp == 12.0)
  buffer.Store<double>(0, 13.0);
```

Fixes #7038

---------

Signed-off-by: Nathan Gauër <[email protected]>
  • Loading branch information
Keenuts authored Jan 17, 2025
1 parent e52b6bc commit e4636f0
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 235 deletions.
82 changes: 29 additions & 53 deletions tools/clang/lib/SPIRV/RawBufferMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,48 +117,32 @@ SpirvInstruction *RawBufferHandler::load64Bits(SpirvInstruction *buffer,
SpirvInstruction *ptr = nullptr;
auto *constUint0 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
auto *constUint32 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

// Load the first word and increment index.
auto *index = address.getWordIndex(loc, range);

// Need to perform two 32-bit uint loads and construct a 64-bit value.

// Load the first 32-bit uint (word0).
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
SpirvInstruction *word0 =
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);
// Increment the base index
address.incrementWordIndex(loc, range);

// Load the second word and increment index.
index = address.getWordIndex(loc, range);
// Load the second 32-bit uint (word1).
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
SpirvInstruction *word1 =
spvBuilder.createLoad(astContext.UnsignedIntTy, ptr, loc, range);

// Convert both word0 and word1 to 64-bit uints.
word0 = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word0, loc, range);
word1 = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedLongLongTy, word1, loc, range);

// Shift word1 to the left by 32 bits.
word1 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
astContext.UnsignedLongLongTy, word1,
constUint32, loc, range);

// BitwiseOr word0 and word1.
result = spvBuilder.createBinaryOp(spv::Op::OpBitwiseOr,
astContext.UnsignedLongLongTy, word0,
word1, loc, range);
result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
target64BitType, loc, range);
result->setRValue();

address.incrementWordIndex(loc, range);

// Combine the 2 words into a composite, and bitcast into the destination
// type.
const auto uintVec2Type =
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
auto *operand = spvBuilder.createCompositeConstruct(
uintVec2Type, {word0, word1}, loc, range);
result = spvBuilder.createUnaryOp(spv::Op::OpBitcast, target64BitType,
operand, loc, range);
result->setRValue();
return result;
}

Expand Down Expand Up @@ -441,39 +425,31 @@ void RawBufferHandler::store64Bits(SpirvInstruction *value,
const auto loc = buffer->getSourceLocation();
auto *constUint0 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
auto *constUint32 =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));

auto *index = address.getWordIndex(loc, range);
// Bitcast the source into a 32-bit words composite.
const auto uintVec2Type =
astContext.getExtVectorType(astContext.UnsignedIntTy, 2);
auto *tmp = spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value,
loc, range);

// The underlying element type of the ByteAddressBuffer is uint. So we
// need to store two 32-bit values.
// Extract the low and high word (careful! word order).
auto *A = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
{0}, loc, range);
auto *B = spvBuilder.createCompositeExtract(astContext.UnsignedIntTy, tmp,
{1}, loc, range);

// Store the first word, and increment counter.
auto *index = address.getWordIndex(loc, range);
auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
// First convert the 64-bit value to uint64_t. Then extract two 32-bit words
// from it.
value = bitCastToNumericalOrBool(value, valueType,
astContext.UnsignedLongLongTy, loc, range);

// Use OpUConvert to perform truncation (produces the least significant bits).
SpirvInstruction *lsb = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedIntTy, value, loc, range);

// Shift uint64_t to the right by 32 bits and truncate to get the most
// significant bits.
SpirvInstruction *msb = spvBuilder.createUnaryOp(
spv::Op::OpUConvert, astContext.UnsignedIntTy,
spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical,
astContext.UnsignedLongLongTy, value,
constUint32, loc, range),
loc, range);

spvBuilder.createStore(ptr, lsb, loc, range);
spvBuilder.createStore(ptr, A, loc, range);
address.incrementWordIndex(loc, range);

// Store the second word, and increment counter.
index = address.getWordIndex(loc, range);
ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
{constUint0, index}, loc, range);
spvBuilder.createStore(ptr, msb, loc, range);
spvBuilder.createStore(ptr, B, loc, range);
address.incrementWordIndex(loc, range);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %dxc -T cs_6_0 -E main -O0 %s -spirv | FileCheck %s

// CHECK-NOT: OpCapability Int64
// CHECK-DAG: OpCapability Float64
// CHECK-NOT: OpCapability Int64

RWByteAddressBuffer buffer;

[numthreads(1, 1, 1)]
void main() {
double tmp;

// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
// CHECK: [[word0:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
// CHECK: [[word1:%[0-9]+]] = OpLoad %uint [[ptr]]
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0]] [[word1]]
// CHECK: [[value:%[0-9]+]] = OpBitcast %double [[merge]]
// CHECK: OpStore %tmp [[value]]
tmp = buffer.Load<double>(0);

// CHECK: [[value:%[0-9]+]] = OpLoad %double %tmp
// CHECK: [[merge:%[0-9]+]] = OpBitcast %v2uint [[value]]
// CHECK: [[word0:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 0
// CHECK: [[word1:%[0-9]+]] = OpCompositeExtract %uint [[merge]] 1

// CHECK: [[addr1:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr1]]
// CHECK: OpStore [[ptr]] [[word0]]
// CHECK: [[addr2:%[0-9]+]] = OpIAdd %uint [[addr1]] %uint_1
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buffer %uint_0 [[addr2]]
// CHECK: OpStore [[ptr]] [[word1]]
// CHECK: [[addr3:%[0-9]+]] = OpIAdd %uint [[addr2]] %uint_1
buffer.Store<double>(0, tmp);
}

Original file line number Diff line number Diff line change
Expand Up @@ -98,53 +98,46 @@ void main(uint3 tid : SV_DispatchThreadId)
// ********* 64-bit matrix ********************

// CHECK: [[index_1:%[0-9]+]] = OpShiftRightLogical %uint [[addr0_1:%[0-9]+]] %uint_2
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
// CHECK: [[word0_ulong:%[0-9]+]] = OpUConvert %ulong [[word0_2]]
// CHECK: [[word1_ulong:%[0-9]+]] = OpUConvert %ulong [[word1_3]]
// CHECK: [[word1_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
// CHECK: [[val0_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word0_ulong]] [[word1_ulong_shifted]]
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[val0_ulong]]
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
// CHECK: [[word2_ulong:%[0-9]+]] = OpUConvert %ulong [[word2_2]]
// CHECK: [[word3_ulong:%[0-9]+]] = OpUConvert %ulong [[word3_0]]
// CHECK: [[word3_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word3_ulong]] %uint_32
// CHECK: [[val1_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word2_ulong]] [[word3_ulong_shifted]]
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[val1_ulong]]
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
// CHECK: [[word4_ulong:%[0-9]+]] = OpUConvert %ulong [[word4_0]]
// CHECK: [[word5_ulong:%[0-9]+]] = OpUConvert %ulong [[word5_0]]
// CHECK: [[word5_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word5_ulong]] %uint_32
// CHECK: [[val2_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word4_ulong]] [[word5_ulong_shifted]]
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[val2_ulong]]
// CHECK: [[ptr_11:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1]]
// CHECK: [[word0_2:%[0-9]+]] = OpLoad %uint [[ptr_11]]
// CHECK: [[index_1_2:%[0-9]+]] = OpIAdd %uint [[index_1]] %uint_1
// CHECK: [[ptr_12:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_1_2]]
// CHECK: [[word1_3:%[0-9]+]] = OpLoad %uint [[ptr_12]]
// CHECK: [[index_2_2:%[0-9]+]] = OpIAdd %uint [[index_1_2]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word0_2]] [[word1_3]]
// CHECK: [[val0_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_13:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2_2]]
// CHECK: [[word2_2:%[0-9]+]] = OpLoad %uint [[ptr_13]]
// CHECK: [[index_3_0:%[0-9]+]] = OpIAdd %uint [[index_2_2]] %uint_1
// CHECK: [[ptr_14:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_3_0]]
// CHECK: [[word3_0:%[0-9]+]] = OpLoad %uint [[ptr_14]]
// CHECK: [[index_4_0:%[0-9]+]] = OpIAdd %uint [[index_3_0]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word2_2]] [[word3_0]]
// CHECK: [[val1_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_15:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4_0]]
// CHECK: [[word4_0:%[0-9]+]] = OpLoad %uint [[ptr_15]]
// CHECK: [[index_5_0:%[0-9]+]] = OpIAdd %uint [[index_4_0]] %uint_1
// CHECK: [[ptr_16:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_5_0]]
// CHECK: [[word5_0:%[0-9]+]] = OpLoad %uint [[ptr_16]]
// CHECK: [[index_6:%[0-9]+]] = OpIAdd %uint [[index_5_0]] %uint_1
// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word4_0]] [[word5_0]]
// CHECK: [[val2_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[ptr_17:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
// CHECK: [[word6:%[0-9]+]] = OpLoad %uint [[ptr_17]]
// CHECK: [[index_7:%[0-9]+]] = OpIAdd %uint [[index_6]] %uint_1
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
// CHECK: [[ptr_18:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_7]]
// CHECK: [[word7:%[0-9]+]] = OpLoad %uint [[ptr_18]]
// CHECK: [[word6_ulong:%[0-9]+]] = OpUConvert %ulong [[word6]]
// CHECK: [[word7_ulong:%[0-9]+]] = OpUConvert %ulong [[word7]]
// CHECK: [[word7_ulong_shifted:%[0-9]+]] = OpShiftLeftLogical %ulong [[word7_ulong]] %uint_32
// CHECK: [[val3_ulong:%[0-9]+]] = OpBitwiseOr %ulong [[word6_ulong]] [[word7_ulong_shifted]]
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[val3_ulong]]
// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
// CHECK: OpStore %f64 [[matrix_1]]
// CHECK: [[index_8:%[0-9]+]] = OpIAdd %uint [[index_7]] %uint_1
// CHECK: [[merge:%[0-9]+]] = OpCompositeConstruct %v2uint [[word6]] [[word7]]
// CHECK: [[val3_1:%[0-9]+]] = OpBitcast %double [[merge]]

// CHECK: [[row0_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val0_1]] [[val2_1]]
// CHECK: [[row1_1:%[0-9]+]] = OpCompositeConstruct %v2double [[val1_1]] [[val3_1]]
// CHECK: [[matrix_1:%[0-9]+]] = OpCompositeConstruct %mat2v2double [[row0_1]] [[row1_1]]
// CHECK: OpStore %f64 [[matrix_1]]
float64_t2x2 f64 = buf.Load<float64_t2x2>(tid.x);

// ********* array of matrices ********************
Expand Down
Loading

0 comments on commit e4636f0

Please sign in to comment.