Skip to content

Commit

Permalink
[bug] Fix crash when lowering multi-dimension groupshared variable (m…
Browse files Browse the repository at this point in the history
…icrosoft#5895)

This commit fixes a crash in the compiler when lowering a groupshared
variable with a multi-dimensional array type. The root cause of the bug
was that we had a nested gep expression that could not be merged into a
single gep because of an intervening addrspacecast.

The `MultiDimArrayToOneDimArray` pass flattens the multi-dimension
global variables to a single dimension. It relies on the `MergeGepUse`
function to flatten any nested geps into a single gep that fully
dereferences a scalar element.

The fix is to modify the `MergeGepUse` function to look through
addrspacecast instructions when trying to merge geps. We can now merge
geps like

    gep(addrspacecast(gep(p0, gep_args0)) to p1*, gep_args1)

into

    addrspacecast(gep(p0, gep_args0+gep_args1) to p1*)

We also added a call to `removeDeadConstantUsers` before flattening
multi-dimension globals because we can have some dead constants hanging
around after merging geps and these constants should be ignored by the
flattening pass.
  • Loading branch information
dmpots authored Oct 23, 2023
1 parent 0a5396c commit 481090a
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 17 deletions.
20 changes: 20 additions & 0 deletions include/llvm/IR/Operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,26 @@ class BitCastOperator
}
};

// HLSL CHANGE: Add this helper class from upstream.
// Operator view over a value whose opcode is Instruction::AddrSpaceCast.
// AddrSpaceCastInst and ConstantExpr are declared as friends, and callers
// use this type to inspect an addrspacecast without first checking whether
// it is an instruction or a constant expression.
class AddrSpaceCastOperator
: public ConcreteOperator<Operator, Instruction::AddrSpaceCast> {
friend class AddrSpaceCastInst;
friend class ConstantExpr;

public:
// Returns the pointer being cast (operand 0).
Value *getPointerOperand() { return getOperand(0); }

// Const overload of getPointerOperand().
const Value *getPointerOperand() const { return getOperand(0); }

// Returns the address space of the pointer operand's type, i.e. the address
// space being cast from.
unsigned getSrcAddressSpace() const {
return getPointerOperand()->getType()->getPointerAddressSpace();
}

// Returns the address space of this value's own type, i.e. the address
// space being cast to.
unsigned getDestAddressSpace() const {
return getType()->getPointerAddressSpace();
}
};

} // End llvm namespace

#endif
88 changes: 71 additions & 17 deletions lib/DXIL/DxilUtilDbgInfoAndMisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,18 @@ using namespace hlsl;

namespace {

Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
// Attempt to merge the two GEPs into a single GEP.
//
// If `AsCast` is non-null the merged GEP will be wrapped
// in an addrspacecast before replacing users. This allows
// merging GEPs of the form
//
// gep(addrspacecast(gep(p0, gep_args0) to p1*), gep_args1)
// into
// addrspacecast(gep(p0, gep_args0+gep_args1) to p1*)
//
Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP,
AddrSpaceCastOperator *AsCast) {
IRBuilder<> Builder(GEP->getContext());
StringRef Name = "";
if (Instruction *I = dyn_cast<Instruction>(GEP)) {
Expand Down Expand Up @@ -75,7 +86,7 @@ Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
}

// Update the GEP in place if possible.
if (SrcGEP->getNumOperands() == 2) {
if (SrcGEP->getNumOperands() == 2 && !AsCast) {
GEP->setOperand(0, SrcGEP->getOperand(0));
GEP->setOperand(1, Sum);
return GEP;
Expand All @@ -94,12 +105,64 @@ Value *MergeGEP(GEPOperator *SrcGEP, GEPOperator *GEP) {
DXASSERT(!Indices.empty(), "must merge");
Value *newGEP =
Builder.CreateInBoundsGEP(nullptr, SrcGEP->getOperand(0), Indices, Name);

// Wrap the new gep in an addrspacecast if needed.
if (AsCast)
newGEP = Builder.CreateAddrSpaceCast(
newGEP, PointerType::get(GEP->getType()->getPointerElementType(),
AsCast->getDestAddressSpace()));
GEP->replaceAllUsesWith(newGEP);
if (Instruction *I = dyn_cast<Instruction>(GEP))
I->eraseFromParent();
return newGEP;
}

// Fold `GEP` into the gep that produces its pointer operand, if there is one.
//
// Two shapes are recognized:
//
//   gep(gep(p))
//   gep(addrspacecast(gep(p)))
//
// Returns the value produced by MergeGEP on success. Returns nullptr when the
// pointer operand is not a (possibly addrspacecast-wrapped) gep, or when
// MergeGEP reports that the two geps could not be merged.
//
// After a successful merge the addrspacecast and the inner gep are erased too,
// provided they are instructions and have no users left.
//
// NOTE: "Mege" in the name is a misspelling of "Merge"; it is kept as-is
// because callers refer to the function by this name.
static Value *TryMegeWithNestedGEP(GEPOperator *GEP) {
  Value *BasePtr = GEP->getPointerOperand();

  // Find the inner gep, either directly or behind a single addrspacecast.
  AddrSpaceCastOperator *Cast = nullptr;
  GEPOperator *InnerGEP = dyn_cast<GEPOperator>(BasePtr);
  if (InnerGEP == nullptr) {
    Cast = dyn_cast<AddrSpaceCastOperator>(BasePtr);
    if (Cast == nullptr)
      return nullptr;

    InnerGEP = dyn_cast<GEPOperator>(Cast->getPointerOperand());
    if (InnerGEP == nullptr)
      return nullptr;
  }

  // Merge the pair; a null result means the geps could not be combined.
  Value *Merged = MergeGEP(InnerGEP, GEP, Cast);
  if (Merged == nullptr)
    return nullptr;

  // Clean up the addrspacecast first, then the inner gep. Only instructions
  // with no remaining users are erased; constant expressions are left alone.
  if (Cast != nullptr && Cast->user_empty()) {
    if (AddrSpaceCastInst *CastInst = dyn_cast<AddrSpaceCastInst>(Cast))
      CastInst->eraseFromParent();
  }

  if (InnerGEP->user_empty()) {
    if (GetElementPtrInst *InnerInst = dyn_cast<GetElementPtrInst>(InnerGEP))
      InnerInst->eraseFromParent();
  }

  return Merged;
}

} // namespace

namespace hlsl {
Expand Down Expand Up @@ -130,23 +193,14 @@ bool MergeGepUse(Value *V) {
// merge any GEP users of the untranslated bitcast
addUsersToWorklist(V);
}
} else if (isa<AddrSpaceCastOperator>(V)) {
addUsersToWorklist(V);
} else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
if (GEPOperator *prevGEP =
dyn_cast<GEPOperator>(GEP->getPointerOperand())) {
// merge the 2 GEPs, returns nullptr if couldn't merge
if (Value *newGEP = MergeGEP(prevGEP, GEP)) {
changed = true;
worklist.push_back(newGEP);
// delete prevGEP if no more users
if (prevGEP->user_empty() && isa<GetElementPtrInst>(prevGEP)) {
cast<GetElementPtrInst>(prevGEP)->eraseFromParent();
}
} else {
addUsersToWorklist(GEP);
}
if (Value *newGEP = TryMegeWithNestedGEP(GEP)) {
changed = true;
worklist.push_back(newGEP);
} else {
// nothing to merge yet, add GEP users
addUsersToWorklist(V);
addUsersToWorklist(GEP);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/Transforms/Scalar/LowerTypePasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ bool LowerTypePass::runOnModule(Module &M) {
HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, NewGV);
}
// Replace users.
GV->removeDeadConstantUsers();
lowerUseWithNewValue(GV, NewGV);
// Remove GV.
GV->removeDeadConstantUsers();
Expand Down
198 changes: 198 additions & 0 deletions test/HLSL/passes/multi_dim_one_dim/gep_addrspacecast_gep.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
; RUN: opt -S -multi-dim-one-dim %s | FileCheck %s
;
; Tests for the pass that changes multi-dimension global variable accesses into
; a flattened one-dimensional access. The tests focus on the case where the geps
; need to be merged but are separated by an addrspacecast operation. This was
; causing the pass to fail because it could not merge the gep through the
; addrspace cast.

; Naming convention: gep0_addrspacecast_gep1

target datalayout = "e-m:e-p:32:32-i1:32-i8:32-i16:32-i32:32-i64:64-f16:32-f32:32-f64:64-n8:16:32:64"
target triple = "dxil-ms-dx"

@ArrayOfArray = addrspace(3) global [256 x [9 x float]] undef, align 4
@ArrayOfArrayOfArray = addrspace(3) global [256 x [9 x [3 x float]]] undef, align 4

; Test that we can merge the geps when all parts are instructions.
; CHECK-LABEL: @merge_gep_instr_instr_instr
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
define void @merge_gep_instr_instr_instr() {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 1
%load = load float, float* %gep1
ret void
}

; Test that we can merge the geps when the inner gep is a constant.
; CHECK-LABEL: @merge_gep_instr_instr_const
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
define void @merge_gep_instr_instr_const() {
entry:
%asc = addrspacecast [9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 1
%load = load float, float* %gep1
ret void
}

; Test that we can merge the geps when the addrspacecast and inner gep are constants.
; CHECK-LABEL: @merge_gep_instr_const_const
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
define void @merge_gep_instr_const_const() {
entry:
%gep1 = getelementptr inbounds [9 x float], [9 x float]* addrspacecast ([9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*), i32 0, i32 1
%load = load float, float* %gep1
ret void
}

; Test that we can merge the geps when all parts are constants.
; The label below spells out the full function name so that it cannot match
; any other function whose name merely starts with the same prefix.
; CHECK-LABEL: @merge_gep_const_const_const
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 1) to float*)
define void @merge_gep_const_const_const() {
entry:
  %load = load float, float* getelementptr inbounds ([9 x float], [9 x float]* addrspacecast ([9 x float] addrspace(3)* getelementptr inbounds ([256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 0) to [9 x float]*), i32 0, i32 1)
  ret void
}

; Test that we compute the correct index when the outer array has
; a non-zero constant index.
; CHECK-LABEL: @merge_gep_const_outer_array_index
; CHECK: load float, float* addrspacecast (float addrspace(3)* getelementptr inbounds ([2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 66) to float*)
define void @merge_gep_const_outer_array_index() {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 7
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 3
%load = load float, float* %gep1
ret void
}

; Test that we compute the correct index when the outer array has
; a non-constant index.
; CHECK-LABEL: @merge_gep_dynamic_outer_array_index
; CHECK: %0 = mul i32 %idx, 9
; CHECK: %1 = add i32 3, %0
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
; CHECK: load float, float* %3
define void @merge_gep_dynamic_outer_array_index(i32 %idx) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 3
%load = load float, float* %gep1
ret void
}

; Test that we compute the correct index when both arrays have
; a non-constant index.
; CHECK-LABEL: @merge_gep_dynamic_array_index
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
; CHECK: load float, float* %3
define void @merge_gep_dynamic_array_index(i32 %idx0, i32 %idx1) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx0
%asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [9 x float]*
%gep1 = getelementptr inbounds [9 x float], [9 x float]* %asc, i32 0, i32 %idx1
%load = load float, float* %gep1
ret void
}

; Test that we compute the correct index when there are multiple
; geps after the addrspacecast. This also exercises the case
; where one of the outer geps ends in an array which hits
; an early return in MergeGEP.
; CHECK-LABEL: @merge_gep_multi_level_end_in_sequential_with_addrspace
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to float*
; CHECK: load float, float* %3
define void @merge_gep_multi_level_end_in_sequential_with_addrspace(i32 %idx0, i32 %idx1) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0
%asc = addrspacecast [256 x [9 x float]] addrspace(3)* %gep0 to [256 x [9 x float]]*
%gep1 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]]* %asc, i32 0, i32 %idx0
%gep2 = getelementptr inbounds [9 x float], [9 x float]* %gep1, i32 0, i32 %idx1
%load = load float, float* %gep2
ret void
}

; Test that we compute the correct index when there are three levels of geps.
; This also exercises the case where one of the outer geps ends in an
; array which hits an early return in MergeGEP.
; CHECK-LABEL: @merge_gep_multi_level_end_in_sequential
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
; CHECK: load float, float addrspace(3)* %2
define void @merge_gep_multi_level_end_in_sequential(i32 %idx0, i32 %idx1) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0
%gep1 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* %gep0, i32 0, i32 %idx0
%gep2 = getelementptr inbounds [9 x float], [9 x float] addrspace(3)* %gep1, i32 0, i32 %idx1
%load = load float, float addrspace(3)* %gep2
ret void
}

; Test that we compute the correct index when the global has 3 levels of
; nested arrays and an addrspacecast.
; CHECK-LABEL: @merge_gep_multi_level_with_addrspace
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = mul i32 %1, 3
; CHECK: %3 = add i32 %idx2, %2
; CHECK: %4 = getelementptr [6912 x float], [6912 x float] addrspace(3)* @ArrayOfArrayOfArray.1dim, i32 0, i32 %3
; CHECK: %5 = addrspacecast float addrspace(3)* %4 to float*
; CHECK: load float, float* %5
define void @merge_gep_multi_level_with_addrspace(i32 %idx0, i32 %idx1, i32 %idx2) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x [3 x float]]], [256 x [9 x [3 x float]]] addrspace(3)* @ArrayOfArrayOfArray, i32 0, i32 %idx0
%asc = addrspacecast [9 x [3 x float]] addrspace(3)* %gep0 to [9 x [3 x float]]*
%gep1 = getelementptr inbounds [9 x [3 x float]], [9 x [3 x float]]* %asc, i32 0, i32 %idx1
%gep2 = getelementptr inbounds [3 x float], [3 x float]* %gep1, i32 0, i32 %idx2
%load = load float, float* %gep2
ret void
}

; Test that we compute the correct index when the global has 3 levels of
; nested arrays.
; CHECK-LABEL: @merge_gep_multi_level
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = mul i32 %1, 3
; CHECK: %3 = add i32 %idx2, %2
; CHECK: %4 = getelementptr [6912 x float], [6912 x float] addrspace(3)* @ArrayOfArrayOfArray.1dim, i32 0, i32 %3
; CHECK: load float, float addrspace(3)* %4
define void @merge_gep_multi_level(i32 %idx0, i32 %idx1, i32 %idx2) {
entry:
%gep0 = getelementptr inbounds [256 x [9 x [3 x float]]], [256 x [9 x [3 x float]]] addrspace(3)* @ArrayOfArrayOfArray, i32 0, i32 %idx0
%gep1 = getelementptr inbounds [9 x [3 x float]], [9 x [3 x float]] addrspace(3)* %gep0, i32 0, i32 %idx1
%gep2 = getelementptr inbounds [3 x float], [3 x float] addrspace(3)* %gep1, i32 0, i32 %idx2
%load = load float, float addrspace(3)* %gep2
ret void
}

; Test that we compute the correct index when the addrspacecast includes both a
; change in address space and a change in the underlying type. I did not see
; this pattern in IR generated from hlsl, but we can handle this case so I am
; adding a test for it anyway.
; CHECK-LABEL: @addrspace_cast_new_type
; CHECK: %0 = mul i32 %idx0, 9
; CHECK: %1 = add i32 %idx1, %0
; CHECK: %2 = getelementptr [2304 x float], [2304 x float] addrspace(3)* @ArrayOfArray.1dim, i32 0, i32 %1
; CHECK: %3 = addrspacecast float addrspace(3)* %2 to i32*
; CHECK: load i32, i32* %3
define void @addrspace_cast_new_type(i32 %idx0, i32 %idx1) {
entry:
  %gep0 = getelementptr inbounds [256 x [9 x float]], [256 x [9 x float]] addrspace(3)* @ArrayOfArray, i32 0, i32 %idx0
  %asc = addrspacecast [9 x float] addrspace(3)* %gep0 to [3 x i32]*
  %gep1 = getelementptr inbounds [3 x i32], [3 x i32]* %asc, i32 0, i32 %idx1
  %load = load i32, i32* %gep1
  ret void
}
Loading

0 comments on commit 481090a

Please sign in to comment.