Skip to content

Commit

Permalink
GEP LSR pass: improve MulExpr handling
Browse files Browse the repository at this point in the history
Improve how mul expressions are handled in GEP LSR pass.
  • Loading branch information
pkwasnie-intel authored and igcbot committed Feb 14, 2025
1 parent 97698ad commit 547b3b2
Showing 1 changed file with 63 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,30 @@ class Analyzer
void analyze(SmallVectorImpl<ReductionCandidateGroup> &Result);

private:

// Represents deconstructed SCEV expression { start, +, step }.
// Start SCEV will be used to calculate base pointer, and Step SCEV
// will increase new induction variable on each iteration.
struct DeconstructedSCEV
{
DeconstructedSCEV() : Start(nullptr), Step(nullptr), ConvertedMulExpr(false) {}

bool isValid();

const SCEV* Start;
const SCEV* Step;

// True if input SCEV:
// x * { start, +, step }
// Was converted into:
// { x * start, +, x * step }
bool ConvertedMulExpr;
};

void analyzeGEP(GetElementPtrInst *GEP);
bool doInitialValidation(GetElementPtrInst *GEP);

bool deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&Step);
bool isValidStep(const SCEV* S);
bool deconstructSCEV(const SCEV *S, DeconstructedSCEV &Result);

DominatorTree &DT;
Loop &L;
Expand Down Expand Up @@ -961,15 +980,16 @@ void Analyzer::analyzeGEP(GetElementPtrInst *GEP)
if (!SCEVHelper::isValid(S))
return;

const SCEV *Start = nullptr;
const SCEV *Step = nullptr;

if (!deconstructSCEV(S, Start, Step))
Analyzer::DeconstructedSCEV Result;
if (!deconstructSCEV(S, Result))
return;

if (!isValidStep(Step))
if (!Result.isValid())
return;

const SCEV* Start = Result.Start;
const SCEV* Step = Result.Step;

if (S->getType() != Start->getType())
Start = isa<SCEVSignExtendExpr>(S) ? SE.getSignExtendExpr(Start, S->getType()) : SE.getZeroExtendExpr(Start, S->getType());

Expand Down Expand Up @@ -1063,7 +1083,7 @@ bool Analyzer::doInitialValidation(GetElementPtrInst *GEP)
// Takes SCEV expression returned by ScalarEvolution and deconstructs it into
// expected format { start, +, step }. Returns false if expressions can't be
// parsed and reduced.
bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&Step)
bool Analyzer::deconstructSCEV(const SCEV *S, Analyzer::DeconstructedSCEV &Result)
{
// Drop ext instructions to analyze nested content.
S = SCEVHelper::dropExt(S);
Expand All @@ -1075,8 +1095,8 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S
// induction variable.
if (SE.isLoopInvariant(S, &L))
{
Start = S;
Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0);
Result.Start = S;
Result.Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0);
return true;
}

Expand All @@ -1097,10 +1117,10 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S
if (!SE.isLoopInvariant(OpStep, &L))
return false;

Start = Add->getStart();
Step = OpStep;
Result.Start = Add->getStart();
Result.Step = OpStep;

return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E);
return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E);
}

// If expression is:
Expand All @@ -1114,30 +1134,30 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S
if (auto *Add = dyn_cast<SCEVAddExpr>(S))
{
// There can be only one expression with step != 0.
Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0);
Result.Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0);

const SCEV *OpSCEV = nullptr;
const SCEV *OpStep = nullptr;
SCEVHelper::SCEVAddBuilder Builder(SE);

for (auto *Op : Add->operands())
{
if (!deconstructSCEV(Op, OpSCEV, OpStep))
Analyzer::DeconstructedSCEV OpResult;

if (!deconstructSCEV(Op, OpResult))
return false;

if (!OpStep->isZero())
if (!OpResult.Step->isZero())
{
if (!Step->isZero())
if (!Result.Step->isZero())
return false; // unsupported expression with multiple steps
Step = OpStep;
Result.Step = OpResult.Step;
}

Builder.add(OpSCEV);
Builder.add(OpResult.Start);
}

Start = Builder.build();
Result.Start = Builder.build();

return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E);
return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E);
}

// If expression is:
Expand All @@ -1148,60 +1168,62 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S
// Warning: GEP's new index will not be a constant integer, but a new SCEV expression.
if (auto *Mul = dyn_cast<SCEVMulExpr>(S))
{
if (IGC_IS_FLAG_DISABLED(EnableGEPLSRMulExpr))
return false;

// SCEVAddRecExpr will be SCEV with step != 0. Any other SCEV is a multiplier.
bool FoundAddRec = false;
SCEVHelper::SCEVMulBuilder StartBuilder(SE), StepBuilder(SE);

for (auto *Op : Mul->operands())
{
const SCEV *OpSCEV = nullptr;
const SCEV *OpStep = nullptr;
if (!deconstructSCEV(Op, OpSCEV, OpStep))
Analyzer::DeconstructedSCEV OpResult;
if (!deconstructSCEV(Op, OpResult))
return false;

if (OpStep->isZero())
if (OpResult.Step->isZero())
{
StartBuilder.add(OpSCEV);
StepBuilder.add(OpSCEV);
StartBuilder.add(OpResult.Start);
StepBuilder.add(OpResult.Start);
}
else
{
if (FoundAddRec)
return false; // unsupported expression with multiple SCEVAddRecExpr
FoundAddRec = true;

StartBuilder.add(OpSCEV);
StepBuilder.add(OpStep);
StartBuilder.add(OpResult.Start);
StepBuilder.add(OpResult.Step);
}
}

if (!FoundAddRec)
return false;

Start = StartBuilder.build();
Step = StepBuilder.build();
Result.Start = StartBuilder.build();
Result.Step = StepBuilder.build();
Result.ConvertedMulExpr = true;

if (!SE.isLoopInvariant(Step, &L))
if (!SE.isLoopInvariant(Result.Step, &L))
return false;

return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E);
return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E);
}

return false;
}


bool Analyzer::isValidStep(const SCEV* S)
bool Analyzer::DeconstructedSCEV::isValid()
{
auto Ty = SCEVHelper::dropExt(S)->getSCEVType();
if (!Start || !Step)
return false;

// Validate step.
auto Ty = SCEVHelper::dropExt(Step)->getSCEVType();

if (Ty == scConstant)
return true;

if (Ty == scMulExpr && IGC_IS_FLAG_ENABLED(EnableGEPLSRMulExpr))
bool IsMul = Ty == scMulExpr || ConvertedMulExpr;
if (IsMul && IGC_IS_FLAG_ENABLED(EnableGEPLSRMulExpr))
return true;

return IGC_IS_FLAG_ENABLED(EnableGEPLSRUnknownConstantStep);
Expand Down

0 comments on commit 547b3b2

Please sign in to comment.