From 547b3b2ba6c25fbd8e12604a08b5d123bfac180b Mon Sep 17 00:00:00 2001 From: "Kwasniewski, Patryk" Date: Thu, 6 Feb 2025 13:31:50 +0000 Subject: [PATCH] GEP LSR pass: improve MulExpr handling Improve how mul expressions are handled in GEP LSR pass. --- .../GEPLoopStrengthReduction.cpp | 104 +++++++++++------- 1 file changed, 63 insertions(+), 41 deletions(-) diff --git a/IGC/Compiler/Optimizer/OpenCLPasses/GEPLoopStrengthReduction/GEPLoopStrengthReduction.cpp b/IGC/Compiler/Optimizer/OpenCLPasses/GEPLoopStrengthReduction/GEPLoopStrengthReduction.cpp index 56741b4c19b6..0a89d59a08d6 100644 --- a/IGC/Compiler/Optimizer/OpenCLPasses/GEPLoopStrengthReduction/GEPLoopStrengthReduction.cpp +++ b/IGC/Compiler/Optimizer/OpenCLPasses/GEPLoopStrengthReduction/GEPLoopStrengthReduction.cpp @@ -289,11 +289,30 @@ class Analyzer void analyze(SmallVectorImpl &Result); private: + + // Represents deconstructed SCEV expression { start, +, step }. + // Start SCEV will be used to calculate base pointer, and Step SCEV + // will increase new induction variable on each iteration. + struct DeconstructedSCEV + { + DeconstructedSCEV() : Start(nullptr), Step(nullptr), ConvertedMulExpr(false) {} + + bool isValid(); + + const SCEV* Start; + const SCEV* Step; + + // True if input SCEV: + // x * { start, +, step } + // Was converted into: + // { x * start, +, x * step } + bool ConvertedMulExpr; + }; + void analyzeGEP(GetElementPtrInst *GEP); bool doInitialValidation(GetElementPtrInst *GEP); - bool deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&Step); - bool isValidStep(const SCEV* S); + bool deconstructSCEV(const SCEV *S, DeconstructedSCEV &Result); DominatorTree &DT; Loop &L; @@ -961,15 +980,16 @@ void Analyzer::analyzeGEP(GetElementPtrInst *GEP) if (!SCEVHelper::isValid(S)) return; - const SCEV *Start = nullptr; - const SCEV *Step = nullptr; - - if (!deconstructSCEV(S, Start, Step)) + Analyzer::DeconstructedSCEV Result; + if (!deconstructSCEV(S, Result)) return; - if (!isValidStep(Step)) + if (!Result.isValid()) return; + const SCEV* Start = Result.Start; + const SCEV* Step = Result.Step; + if (S->getType() != Start->getType()) Start = isa(S) ? SE.getSignExtendExpr(Start, S->getType()) : SE.getZeroExtendExpr(Start, S->getType()); @@ -1063,7 +1083,7 @@ bool Analyzer::doInitialValidation(GetElementPtrInst *GEP) // Takes SCEV expression returned by ScalarEvolution and deconstructs it into // expected format { start, +, step }. Returns false if expressions can't be // parsed and reduced. -bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&Step) +bool Analyzer::deconstructSCEV(const SCEV *S, Analyzer::DeconstructedSCEV &Result) { // Drop ext instructions to analyze nested content. S = SCEVHelper::dropExt(S); @@ -1075,8 +1095,8 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S // induction variable. if (SE.isLoopInvariant(S, &L)) { - Start = S; - Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0); + Result.Start = S; + Result.Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0); return true; } @@ -1097,10 +1117,10 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S if (!SE.isLoopInvariant(OpStep, &L)) return false; - Start = Add->getStart(); - Step = OpStep; + Result.Start = Add->getStart(); + Result.Step = OpStep; - return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E); + return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E); } // If expression is: @@ -1114,30 +1134,30 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S if (auto *Add = dyn_cast(S)) { // There can be only one expression with step != 0. - Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0); + Result.Step = SE.getConstant(Type::getInt64Ty(L.getHeader()->getContext()), 0); - const SCEV *OpSCEV = nullptr; - const SCEV *OpStep = nullptr; SCEVHelper::SCEVAddBuilder Builder(SE); for (auto *Op : Add->operands()) { - if (!deconstructSCEV(Op, OpSCEV, OpStep)) + Analyzer::DeconstructedSCEV OpResult; + + if (!deconstructSCEV(Op, OpResult)) return false; - if (!OpStep->isZero()) + if (!OpResult.Step->isZero()) { - if (!Step->isZero()) + if (!Result.Step->isZero()) return false; // unsupported expression with multiple steps - Step = OpStep; + Result.Step = OpResult.Step; } - Builder.add(OpSCEV); + Builder.add(OpResult.Start); } - Start = Builder.build(); + Result.Start = Builder.build(); - return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E); + return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E); } // If expression is: @@ -1148,24 +1168,20 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S // Warning: GEP's new index will not be a constant integer, but a new SCEV expression. if (auto *Mul = dyn_cast(S)) { - if (IGC_IS_FLAG_DISABLED(EnableGEPLSRMulExpr)) - return false; - // SCEVAddRecExpr will be SCEV with step != 0. Any other SCEV is a multiplier. bool FoundAddRec = false; SCEVHelper::SCEVMulBuilder StartBuilder(SE), StepBuilder(SE); for (auto *Op : Mul->operands()) { - const SCEV *OpSCEV = nullptr; - const SCEV *OpStep = nullptr; - if (!deconstructSCEV(Op, OpSCEV, OpStep)) + Analyzer::DeconstructedSCEV OpResult; + if (!deconstructSCEV(Op, OpResult)) return false; - if (OpStep->isZero()) + if (OpResult.Step->isZero()) { - StartBuilder.add(OpSCEV); - StepBuilder.add(OpSCEV); + StartBuilder.add(OpResult.Start); + StepBuilder.add(OpResult.Start); } else { @@ -1173,35 +1189,41 @@ bool Analyzer::deconstructSCEV(const SCEV *S, const SCEV *&Start, const SCEV *&S return false; // unsupported expression with multiple SCEVAddRecExpr FoundAddRec = true; - StartBuilder.add(OpSCEV); - StepBuilder.add(OpStep); + StartBuilder.add(OpResult.Start); + StepBuilder.add(OpResult.Step); } } if (!FoundAddRec) return false; - Start = StartBuilder.build(); - Step = StepBuilder.build(); + Result.Start = StartBuilder.build(); + Result.Step = StepBuilder.build(); + Result.ConvertedMulExpr = true; - if (!SE.isLoopInvariant(Step, &L)) + if (!SE.isLoopInvariant(Result.Step, &L)) return false; - return IGCLLVM::isSafeToExpandAt(Start, &L.getLoopPreheader()->back(), &SE, &E); + return IGCLLVM::isSafeToExpandAt(Result.Start, &L.getLoopPreheader()->back(), &SE, &E); } return false; } -bool Analyzer::isValidStep(const SCEV* S) +bool Analyzer::DeconstructedSCEV::isValid() { - auto Ty = SCEVHelper::dropExt(S)->getSCEVType(); + if (!Start || !Step) + return false; + + // Validate step. + auto Ty = SCEVHelper::dropExt(Step)->getSCEVType(); if (Ty == scConstant) return true; - if (Ty == scMulExpr && IGC_IS_FLAG_ENABLED(EnableGEPLSRMulExpr)) + bool IsMul = Ty == scMulExpr || ConvertedMulExpr; + if (IsMul && IGC_IS_FLAG_ENABLED(EnableGEPLSRMulExpr)) return true; return IGC_IS_FLAG_ENABLED(EnableGEPLSRUnknownConstantStep);