From 267c86618d90e7e6d393f45c282fc66c72c334f8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 4 Dec 2024 20:25:00 +0100 Subject: [PATCH 1/2] Flip order of operands to Screma. Now they resemble the logical execution order: first the map function, then reduce/scan operations if applicable. --- src/Futhark/AD/Fwd.hs | 6 +- src/Futhark/AD/Rev/SOAC.hs | 2 +- src/Futhark/Analysis/HORep/MapNest.hs | 2 +- src/Futhark/Analysis/HORep/SOAC.hs | 6 +- src/Futhark/IR/Parse.hs | 21 ++--- src/Futhark/IR/SOACS.hs | 2 +- src/Futhark/IR/SOACS/SOAC.hs | 81 +++++++++----------- src/Futhark/IR/SOACS/Simplify.hs | 28 ++++--- src/Futhark/Optimise/Fusion.hs | 2 +- src/Futhark/Optimise/Fusion/RulesWithAccs.hs | 4 +- src/Futhark/Optimise/Fusion/TryFusion.hs | 17 ++-- src/Futhark/Optimise/GenRedOpt.hs | 2 +- src/Futhark/Pass/ExtractKernels.hs | 6 +- src/Futhark/Tools.hs | 8 +- src/Futhark/Transform/FirstOrderTransform.hs | 2 +- 15 files changed, 93 insertions(+), 96 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 52e5ff87a3..e194ec20d6 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -264,13 +264,13 @@ zeroFromSubExp (Var v) = do letExp "zero" $ zeroExp t fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () -fwdSOAC pat aux (Screma size xs (ScremaForm scs reds f)) = do +fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do pat' <- bundleNewPat pat xs' <- bundleTangents xs + f' <- fwdLambda f scs' <- mapM fwdScan scs reds' <- mapM fwdRed reds - f' <- fwdLambda f - addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm scs' reds' f' + addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds' where fwdScan :: Scan SOACS -> ADM (Scan SOACS) fwdScan sc = do diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 7fe54468d1..7df5c47cbc 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -227,7 +227,7 @@ mapOp (Lambda [pa1, pa2] _ lam_body) cs == mempty, [map_stm] <- stmsToList (bodyStms lam_body), (Let (Pat [pe]) _ (Op scrm)) <- map_stm, - (Screma _ [a1, a2] (ScremaForm [] [] map_lam)) <- scrm, + (Screma _ [a1, a2] (ScremaForm map_lam [] [])) <- scrm, (a1 == paramName pa1 && a2 == paramName pa2) || (a1 == paramName pa2 && a2 == paramName pa1), r == Var (patElemName pe) = Just map_lam diff --git a/src/Futhark/Analysis/HORep/MapNest.hs b/src/Futhark/Analysis/HORep/MapNest.hs index 3af1367354..f3104e17c6 100644 --- a/src/Futhark/Analysis/HORep/MapNest.hs +++ b/src/Futhark/Analysis/HORep/MapNest.hs @@ -76,7 +76,7 @@ fromSOAC' :: [Ident] -> SOAC rep -> m (Maybe (MapNest rep)) -fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm [] [] lam)) = do +fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm lam [] [])) = do maybenest <- case ( stmsToList $ bodyStms $ lambdaBody lam, bodyResult $ lambdaBody lam ) of diff --git a/src/Futhark/Analysis/HORep/SOAC.hs b/src/Futhark/Analysis/HORep/SOAC.hs index bc1ac85b88..7854ab7dd4 100644 --- a/src/Futhark/Analysis/HORep/SOAC.hs +++ b/src/Futhark/Analysis/HORep/SOAC.hs @@ -435,7 +435,7 @@ newWidth (inp : _) _ = arraySize 0 $ inputType inp lambda :: SOAC rep -> Lambda rep lambda (Stream _ _ _ lam) = lam lambda (Scatter _len _ivs _spec lam) = lam -lambda (Screma _ _ (ScremaForm _ _ lam)) = lam +lambda (Screma _ _ (ScremaForm lam _ _)) = lam lambda (Hist _ _ _ lam) = lam -- | Set the lambda used in the SOAC. @@ -444,8 +444,8 @@ setLambda lam (Stream w arrs nes _) = Stream w arrs nes lam setLambda lam (Scatter len arrs spec _lam) = Scatter len arrs spec lam -setLambda lam (Screma w arrs (ScremaForm scan red _)) = - Screma w arrs (ScremaForm scan red lam) +setLambda lam (Screma w arrs (ScremaForm _ scan red)) = + Screma w arrs (ScremaForm lam scan red) setLambda lam (Hist w ops inps _) = Hist w ops inps lam diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index 2bafc354a1..71a2ef1755 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -727,24 +727,25 @@ pSOAC pr = <*> p pScremaForm = SOAC.ScremaForm - <$> braces (pScan pr `sepBy` pComma) + <$> pLambda pr <* pComma - <*> braces (pReduce pr `sepBy` pComma) + <*> braces (pScan pr `sepBy` pComma) <* pComma - <*> pLambda pr + <*> braces (pReduce pr `sepBy` pComma) pRedomapForm = - SOAC.ScremaForm mempty - <$> braces (pReduce pr `sepBy` pComma) + SOAC.ScremaForm + <$> pLambda pr + <*> pure [] <* pComma - <*> pLambda pr + <*> braces (pReduce pr `sepBy` pComma) pScanomapForm = SOAC.ScremaForm - <$> braces (pScan pr `sepBy` pComma) + <$> pLambda pr <* pComma - <*> pure mempty - <*> pLambda pr + <*> braces (pScan pr `sepBy` pComma) + <*> pure [] pMapForm = - SOAC.ScremaForm mempty mempty <$> pLambda pr + SOAC.ScremaForm <$> pLambda pr <*> pure mempty <*> pure mempty pScatter = keyword "scatter" *> parens diff --git a/src/Futhark/IR/SOACS.hs b/src/Futhark/IR/SOACS.hs index 7018acb928..151efd3923 100644 --- a/src/Futhark/IR/SOACS.hs +++ b/src/Futhark/IR/SOACS.hs @@ -55,7 +55,7 @@ usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog) expUsesAD (Op JVP {}) = True expUsesAD (Op VJP {}) = True expUsesAD (Op (Stream _ _ _ lam)) = lamUsesAD lam - expUsesAD (Op (Screma _ _ (ScremaForm scans reds lam))) = + expUsesAD (Op (Screma _ _ (ScremaForm lam scans reds))) = lamUsesAD lam || any (lamUsesAD . scanLambda) scans || any (lamUsesAD . redLambda) reds diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs index eea5fdab79..e048b287a3 100644 --- a/src/Futhark/IR/SOACS/SOAC.hs +++ b/src/Futhark/IR/SOACS/SOAC.hs @@ -148,14 +148,14 @@ data HistOp rep = HistOp -- | The essential parts of a 'Screma' factored out (everything -- except the input arrays). data ScremaForm rep = ScremaForm - { scremaScans :: [Scan rep], - scremaReduces :: [Reduce rep], - -- | The "main" lambda of the Screma. For a map, this is + { -- | The "main" lambda of the Screma. For a map, this is -- equivalent to 'isMapSOAC'. Note that the meaning of the return -- value of this lambda depends crucially on exactly which Screma -- this is. The parameters will correspond exactly to elements of -- the input arrays, however. - scremaLambda :: Lambda rep + scremaLambda :: Lambda rep, + scremaScans :: [Scan rep], + scremaReduces :: [Reduce rep] } deriving (Eq, Ord, Show) @@ -221,7 +221,7 @@ singleReduce reds = -- | The types produced by a single 'Screma', given the size of the -- input array. scremaType :: SubExp -> ScremaForm rep -> [Type] -scremaType w (ScremaForm scans reds map_lam) = +scremaType w (ScremaForm map_lam scans reds) = scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps where scan_tps = @@ -258,12 +258,12 @@ nilFn = Lambda mempty mempty (mkBody mempty mempty) -- | Construct a Screma with possibly multiple scans, and -- the given map function. scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep -scanomapSOAC scans = ScremaForm scans [] +scanomapSOAC scans lam = ScremaForm lam scans [] -- | Construct a Screma with possibly multiple reductions, and -- the given map function. redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep -redomapSOAC = ScremaForm [] +redomapSOAC reds lam = ScremaForm lam [] reds -- | Construct a Screma with possibly multiple scans, and identity map -- function. @@ -287,11 +287,11 @@ reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts -- | Construct a Screma corresponding to a map. mapSOAC :: Lambda rep -> ScremaForm rep -mapSOAC = ScremaForm [] [] +mapSOAC lam = ScremaForm lam [] [] -- | Does this Screma correspond to a scan-map composition? isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep) -isScanomapSOAC (ScremaForm scans reds map_lam) = do +isScanomapSOAC (ScremaForm map_lam scans reds) = do guard $ null reds guard $ not $ null scans pure (scans, map_lam) @@ -305,7 +305,7 @@ isScanSOAC form = do -- | Does this Screma correspond to a reduce-map composition? isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep) -isRedomapSOAC (ScremaForm scans reds map_lam) = do +isRedomapSOAC (ScremaForm map_lam scans reds) = do guard $ null scans guard $ not $ null reds pure (reds, map_lam) @@ -320,7 +320,7 @@ isReduceSOAC form = do -- | Does this Screma correspond to a simple map, without any -- reduction or scan results? isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep) -isMapSOAC (ScremaForm scans reds map_lam) = do +isMapSOAC (ScremaForm map_lam scans reds) = do guard $ null scans guard $ null reds pure map_lam @@ -443,12 +443,13 @@ mapSOACM tv (Hist w arrs ops bucket_fun) = ) ops <*> mapOnSOACLambda tv bucket_fun -mapSOACM tv (Screma w arrs (ScremaForm scans reds map_lam)) = +mapSOACM tv (Screma w arrs (ScremaForm map_lam scans reds)) = Screma <$> mapOnSOACSubExp tv w <*> mapM (mapOnSOACVName tv) arrs <*> ( ScremaForm - <$> forM + <$> mapOnSOACLambda tv map_lam + <*> forM scans ( \(Scan red_lam red_nes) -> Scan @@ -462,7 +463,6 @@ mapSOACM tv (Screma w arrs (ScremaForm scans reds map_lam)) = <$> mapOnSOACLambda tv red_lam <*> mapM (mapOnSOACSubExp tv) red_nes ) - <*> mapOnSOACLambda tv map_lam ) -- | A helper for defining 'TraverseOpStms'. @@ -547,7 +547,7 @@ instance AliasedOp SOAC where consumedInOp VJP {} = mempty -- Only map functions can consume anything. The operands to scan -- and reduce functions are always considered "fresh". - consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) = + consumedInOp (Screma _ arrs (ScremaForm map_lam _ _)) = mapNames consumedArray $ consumedByLambda map_lam where consumedArray v = fromMaybe v $ lookup v params_to_arrs @@ -586,12 +586,12 @@ instance CanBeAliased SOAC where arrs (map (mapHistOp (Alias.analyseLambda aliases)) ops) (Alias.analyseLambda aliases bucket_fun) - addOpAliases aliases (Screma w arrs (ScremaForm scans reds map_lam)) = + addOpAliases aliases (Screma w arrs (ScremaForm map_lam scans reds)) = Screma w arrs $ ScremaForm + (Alias.analyseLambda aliases map_lam) (map onScan scans) (map onRed reds) - (Alias.analyseLambda aliases map_lam) where onRed red = red {redLambda = Alias.analyseLambda aliases $ redLambda red} onScan scan = scan {scanLambda = Alias.analyseLambda aliases $ scanLambda scan} @@ -642,7 +642,7 @@ instance IsOp SOAC where lam (zipWith (<>) (map depsOf' args) (map depsOf' vec)) <> map (const $ freeIn args <> freeIn lam) (lambdaParams lam) - opDependencies (Screma w arrs (ScremaForm scans reds map_lam)) = + opDependencies (Screma w arrs (ScremaForm map_lam scans reds)) = let (scans_in, reds_in, map_deps) = splitAt3 (scanResults scans) (redResults reds) $ lambdaDependencies mempty map_lam (depsOfArrays w arrs) @@ -682,7 +682,7 @@ instance (RepTypes rep) => ST.IndexOp (SOAC rep) where SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes' _ -> Nothing where - lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) = + lambdaAndSubExp (Screma _ arrs (ScremaForm map_lam scans reds)) = nthMapOut (scanResults scans + redResults reds) map_lam arrs lambdaAndSubExp _ = Nothing @@ -849,7 +849,7 @@ typeCheckSOAC (Hist w arrs ops bucket_fun) = do <> prettyTuple (lambdaReturnType bucket_fun) <> " but should have type " <> prettyTuple bucket_ret_t -typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do +typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do TC.require [Prim int64] w arrs' <- TC.checkSOACArrayArgs w arrs TC.checkLambda map_lam arrs' @@ -906,12 +906,12 @@ instance RephraseOp SOAC where where onOp (HistOp dest_shape rf dests nes op) = HistOp dest_shape rf dests nes <$> rephraseLambda r op - rephraseInOp r (Screma w arrs (ScremaForm scans red lam)) = + rephraseInOp r (Screma w arrs (ScremaForm lam scans red)) = Screma w arrs <$> ( ScremaForm - <$> mapM onScan scans + <$> rephraseLambda r lam + <*> mapM onScan scans <*> mapM onRed red - <*> rephraseLambda r lam ) where onScan (Scan op nes) = Scan <$> rephraseLambda r op <*> pure nes @@ -928,11 +928,11 @@ instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where inside "Scatter" $ lambdaMetrics lam opMetrics (Hist _ _ ops bucket_fun) = inside "Hist" $ mapM_ (lambdaMetrics . histOp) ops >> lambdaMetrics bucket_fun - opMetrics (Screma _ _ (ScremaForm scans reds map_lam)) = + opMetrics (Screma _ _ (ScremaForm map_lam scans reds)) = inside "Screma" $ do + lambdaMetrics map_lam mapM_ (lambdaMetrics . scanLambda) scans mapM_ (lambdaMetrics . redLambda) reds - lambdaMetrics map_lam instance (PrettyRep rep) => PP.Pretty (SOAC rep) where pretty (VJP lam args vec) = @@ -961,56 +961,49 @@ instance (PrettyRep rep) => PP.Pretty (SOAC rep) where ppScatter w arrs dests lam pretty (Hist w arrs ops bucket_fun) = ppHist w arrs ops bucket_fun - pretty (Screma w arrs (ScremaForm scans reds map_lam)) + pretty (Screma w arrs (ScremaForm map_lam scans reds)) | null scans, null reds = "map" <> (parens . align) ( pretty w - <> comma - ppTuple' (map pretty arrs) - <> comma - pretty map_lam + <> comma ppTuple' (map pretty arrs) + <> comma pretty map_lam ) | null scans = "redomap" <> (parens . align) ( pretty w - <> comma - ppTuple' (map pretty arrs) + <> comma ppTuple' (map pretty arrs) + <> comma pretty map_lam <> comma PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) - <> comma - pretty map_lam ) | null reds = "scanomap" <> (parens . align) ( pretty w + <> comma ppTuple' (map pretty arrs) + <> comma pretty map_lam <> comma - ppTuple' (map pretty arrs) - <> comma - PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) - <> comma - pretty map_lam + PP.braces + (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) ) pretty (Screma w arrs form) = ppScrema w arrs form -- | Prettyprint the given Screma. ppScrema :: (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann -ppScrema w arrs (ScremaForm scans reds map_lam) = +ppScrema w arrs (ScremaForm map_lam scans reds) = "screma" <> (parens . align) ( pretty w - <> comma - ppTuple' (map pretty arrs) + <> comma ppTuple' (map pretty arrs) + <> comma pretty map_lam <> comma PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans) <> comma PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds) - <> comma - pretty map_lam ) -- | Prettyprint the given Stream. diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index def030da03..e15a6332b4 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -117,7 +117,7 @@ simplifySOAC (Hist w imgs ops bfun) = do imgs' <- mapM Engine.simplify imgs (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty bfun pure (Hist w' imgs' ops' bfun', mconcat hoisted <> bfun_hoisted) -simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do +simplifySOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do (scans', scans_hoisted) <- fmap unzip $ forM scans $ \(Scan lam nes) -> do (lam', hoisted) <- Engine.simplifyLambda mempty lam @@ -136,7 +136,7 @@ simplifySOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do <$> ( Screma <$> Engine.simplify w <*> Engine.simplify arrs - <*> pure (ScremaForm scans' reds' map_lam') + <*> pure (ScremaForm map_lam' scans' reds') ) <*> pure (mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted) @@ -399,10 +399,10 @@ removeUnusedSOACInput :: TopDownRuleOp rep removeUnusedSOACInput _ pat aux op | Just (Screma w arrs form :: SOAC rep) <- asSOAC op, - ScremaForm scan reduce map_lam <- form, + ScremaForm map_lam scan reduce <- form, Just (used_arrs, map_lam') <- remove map_lam arrs = Simplify . auxing aux . letBind pat . Op $ - soacOp (Screma w used_arrs (ScremaForm scan reduce map_lam')) + soacOp (Screma w used_arrs (ScremaForm map_lam' scan reduce)) | Just (Scatter w arrs dests map_lam :: SOAC rep) <- asSOAC op, Just (used_arrs, map_lam') <- remove map_lam arrs = Simplify . auxing aux . letBind pat . Op $ @@ -418,7 +418,7 @@ removeUnusedSOACInput _ pat aux op removeUnusedSOACInput _ _ _ _ = Skip removeDeadMapping :: BottomUpRuleOp (Wise SOACS) -removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm scans reds lam)) +removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm lam scans reds)) | (nonmap_pes, map_pes) <- splitAt num_nonmap_res pes, not $ null map_pes = let (nonmap_res, map_res) = splitAt num_nonmap_res $ bodyResult $ lambdaBody lam @@ -434,10 +434,8 @@ removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm scans reds in if map_pes /= map_pes' then Simplify . auxing aux $ - letBind (Pat $ nonmap_pes <> map_pes') $ - Op $ - Screma w arrs $ - ScremaForm scans reds lam' + letBind (Pat $ nonmap_pes <> map_pes') . Op $ + Screma w arrs (ScremaForm lam' scans reds) else Skip where num_nonmap_res = scanResults scans + redResults reds @@ -642,7 +640,7 @@ simplifyKnownIterationSOAC :: (Buildable rep, BuilderOps rep, HasSOAC rep) => TopDownRuleOp rep simplifyKnownIterationSOAC _ pat _ op - | Just (Screma (Constant k) arrs (ScremaForm scans reds map_lam)) <- asSOAC op, + | Just (Screma (Constant k) arrs (ScremaForm map_lam scans reds)) <- asSOAC op, oneIsh k = Simplify $ do let (Reduce _ red_lam red_nes) = singleReduce reds (Scan scan_lam scan_nes) = singleScan scans @@ -697,7 +695,7 @@ simplifyKnownIterationSOAC _ pat _ op certifying cs $ letBindNames [v] $ BasicOp $ SubExp se -- simplifyKnownIterationSOAC _ pat aux op - | Just (Screma (Constant (IntValue (Int64Value k))) arrs (ScremaForm [] [] map_lam)) <- asSOAC op, + | Just (Screma (Constant (IntValue (Int64Value k))) arrs (ScremaForm map_lam [] [])) <- asSOAC op, "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do arrs_elems <- fmap transpose . forM [0 .. k - 1] $ \i -> do map_lam' <- renameLambda map_lam @@ -836,7 +834,7 @@ simplifyMapIota :: (Buildable rep, BuilderOps rep, HasSOAC rep) => TopDownRuleOp rep simplifyMapIota vtable screma_pat aux op - | Just (Screma w arrs (ScremaForm scan reduce map_lam) :: SOAC rep) <- asSOAC op, + | Just (Screma w arrs (ScremaForm map_lam scan reduce) :: SOAC rep) <- asSOAC op, Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs), indexings <- mapMaybe (indexesWith (paramName p)) . S.toList $ @@ -855,7 +853,7 @@ simplifyMapIota vtable screma_pat aux op } auxing aux . letBind screma_pat . Op . soacOp $ - Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam') + Screma w (arrs <> more_arrs) (ScremaForm map_lam' scan reduce) where isIota (_, arr) = case ST.lookupBasicOp arr vtable of Just (Iota _ (Constant o) (Constant s) _, _) -> @@ -908,7 +906,7 @@ simplifyMapIota _ _ _ _ = Skip -- corresponding to that transformation performed on the rows of the -- full array. moveTransformToInput :: TopDownRuleOp (Wise SOACS) -moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm scan reduce map_lam)) +moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm map_lam scan reduce)) | ops <- filter arrayIsMapParam $ S.toList $ arrayOps mempty $ lambdaBody map_lam, not $ null ops = Simplify $ do (more_arrs, more_params, replacements) <- @@ -923,7 +921,7 @@ moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm scan } auxing aux . letBind screma_pat . Op $ - Screma w (arrs <> more_arrs) (ScremaForm scan reduce map_lam') + Screma w (arrs <> more_arrs) (ScremaForm map_lam' scan reduce) where -- It is not safe to move the transform if the root array is being -- consumed by the Screma. This is a bit too conservative - it's diff --git a/src/Futhark/Optimise/Fusion.hs b/src/Futhark/Optimise/Fusion.hs index 74be1e3cbd..404dffc793 100644 --- a/src/Futhark/Optimise/Fusion.hs +++ b/src/Futhark/Optimise/Fusion.hs @@ -433,7 +433,7 @@ hFuseNodeT _ _ = pure Nothing removeOutputsExcept :: [VName] -> NodeT -> NodeT removeOutputsExcept toKeep s = case s of - SoacNode ots (Pat pats1) soac@(H.Screma _ _ (ScremaForm scans_1 red_1 lam_1)) aux1 -> + SoacNode ots (Pat pats1) soac@(H.Screma _ _ (ScremaForm lam_1 scans_1 red_1)) aux1 -> SoacNode ots (Pat $ pats_unchanged <> pats_new) (H.setLambda lam_new soac) aux1 where scan_output_size = Futhark.scanResults scans_1 diff --git a/src/Futhark/Optimise/Fusion/RulesWithAccs.hs b/src/Futhark/Optimise/Fusion/RulesWithAccs.hs index b03f14c99c..435db4e654 100644 --- a/src/Futhark/Optimise/Fusion/RulesWithAccs.hs +++ b/src/Futhark/Optimise/Fusion/RulesWithAccs.hs @@ -210,7 +210,7 @@ checkSafeAndProfitable dg scat_node_id ctxs_rshp@(_ : _) ctxs_cons = isMap nT | SoacNode out_trsfs _pat soac _ <- nT, H.Screma _ _ form <- soac, - ScremaForm [] [] _ <- form = + ScremaForm _ [] [] <- form = H.nullTransforms out_trsfs isMap _ = False checkSafeAndProfitable _ _ _ _ = False @@ -341,7 +341,7 @@ mkWithAccBdy' static_arg (dim : dims) dims_rev iot_par_nms rshp_ps cons_ps = do mkWithAccBdy' static_arg dims (dim : dims_rev) (iot_par_nms ++ [paramName iota_p]) rshp_ps' cons_ps' let map_lam = Lambda (rshp_ps' ++ [iota_p] ++ cons_ps') (map paramDec cons_ps') map_lam_bdy map_inps = map paramName rshp_ps ++ [iota_arr] ++ map paramName cons_ps - map_soac = F.Screma dim map_inps $ ScremaForm [] [] map_lam + map_soac = F.Screma dim map_inps $ ScremaForm map_lam [] [] res_nms <- letTupExp "acc_res" $ Op map_soac pure $ map (subExpRes . Var) res_nms diff --git a/src/Futhark/Optimise/Fusion/TryFusion.hs b/src/Futhark/Optimise/Fusion/TryFusion.hs index 9cc9e773d3..a89691777a 100644 --- a/src/Futhark/Optimise/Fusion/TryFusion.hs +++ b/src/Futhark/Optimise/Fusion/TryFusion.hs @@ -262,11 +262,18 @@ fuseSOACwithKer mode unfus_set outVars soac_p ker = do | unfus_set /= mempty, not (SOAC.nullTransforms $ fsOutputTransform ker) -> fail "Cannot perform diagonal fusion in the presence of output transforms." - ( SOAC.Screma _ _ (ScremaForm scans_c reds_c _), - SOAC.Screma _ _ (ScremaForm scans_p reds_p _), + ( SOAC.Screma _ _ (ScremaForm _ scans_c reds_c), + SOAC.Screma _ _ (ScremaForm _ scans_p reds_p), _ ) - | scremaFusionOK (splitAt (Futhark.scanResults scans_p + Futhark.redResults reds_p) outVars) ker -> do + | scremaFusionOK + ( splitAt + ( Futhark.scanResults scans_p + + Futhark.redResults reds_p + ) + outVars + ) + ker -> do let red_nes_p = concatMap redNeutral reds_p red_nes_c = concatMap redNeutral reds_c scan_nes_p = concatMap scanNeutral scans_p @@ -300,7 +307,7 @@ fuseSOACwithKer mode unfus_set outVars soac_p ker = do $ SOAC.Screma w new_inp - (ScremaForm (scans_p ++ scans_c) (reds_p ++ reds_c) res_lam') + (ScremaForm res_lam' (scans_p ++ scans_c) (reds_p ++ reds_c)) ------------------ -- Scatter fusion -- @@ -587,7 +594,7 @@ iswim _ (SOAC.Screma w arrs form) ots t : _ -> 1 : 0 : [2 .. arrayRank t] pure - ( SOAC.Screma map_w map_arrs' (ScremaForm [] [] map_fun'), + ( SOAC.Screma map_w map_arrs' (mapSOAC map_fun'), ots SOAC.|> SOAC.Rearrange map_cs perm ) iswim _ _ _ = diff --git a/src/Futhark/Optimise/GenRedOpt.hs b/src/Futhark/Optimise/GenRedOpt.hs index 16a141024a..dae11758f0 100644 --- a/src/Futhark/Optimise/GenRedOpt.hs +++ b/src/Futhark/Optimise/GenRedOpt.hs @@ -150,7 +150,7 @@ genRed2Tile2d env kerstm@(Let pat_ker aux (Op (SegOp (SegMap seg_thd seg_space k map_lam <- renameLambda map_lam0 (k1_res, ker1_stms) <- runBuilderT' $ do iota <- letExp "iota" $ BasicOp $ Iota inv_dim_len (intConst Int64 0) (intConst Int64 1) Int64 - let op_exp = Op (OtherOp (Screma inv_dim_len [iota] (ScremaForm [] [red] map_lam))) + let op_exp = Op (OtherOp (Screma inv_dim_len [iota] (ScremaForm map_lam [] [red]))) res_redmap <- letTupExp "res_mapred" op_exp letSubExp (baseString pat_acc_nm ++ "_big_update") $ BasicOp (UpdateAcc safety acc_nm acc_inds $ map Var res_redmap) diff --git a/src/Futhark/Pass/ExtractKernels.hs b/src/Futhark/Pass/ExtractKernels.hs index 40b3e1b43f..244e4c40ed 100644 --- a/src/Futhark/Pass/ExtractKernels.hs +++ b/src/Futhark/Pass/ExtractKernels.hs @@ -596,7 +596,7 @@ worthIntrablock lam = bodyInterest (lambdaBody lam) > 1 max (bodyInterest defbody) (map (bodyInterest . caseBody) cases) - | Op (Screma w _ (ScremaForm _ _ lam')) <- stmExp stm = + | Op (Screma w _ (ScremaForm lam' _ _)) <- stmExp stm = zeroIfTooSmall w + bodyInterest (lambdaBody lam') | Op (Stream _ _ _ lam') <- stmExp stm = bodyInterest $ lambdaBody lam' @@ -625,7 +625,7 @@ worthSequentialising lam = bodyInterest (0 :: Int) (lambdaBody lam) > 1 interest depth stm | "sequential" `inAttrs` attrs = 0 :: Int - | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm, + | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm, isJust $ isMapSOAC form = if sequential_inner then 0 @@ -636,7 +636,7 @@ worthSequentialising lam = bodyInterest (0 :: Int) (lambdaBody lam) > 1 bodyInterest (depth + 1) body * 10 | WithAcc _ withacc_lam <- stmExp stm = bodyInterest (depth + 1) (lambdaBody withacc_lam) - | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm = + | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm = 1 + bodyInterest (depth + 1) (lambdaBody lam') + diff --git a/src/Futhark/Tools.hs b/src/Futhark/Tools.hs index 2d72b1fbf1..41da366d44 100644 --- a/src/Futhark/Tools.hs +++ b/src/Futhark/Tools.hs @@ -106,18 +106,16 @@ dissectScrema :: ScremaForm (Rep m) -> [VName] -> m () -dissectScrema pat w (ScremaForm scans reds map_lam) arrs = do +dissectScrema pat w (ScremaForm map_lam scans reds) arrs = do let num_reds = redResults reds num_scans = scanResults scans - (scan_res, red_res, map_res) = - splitAt3 num_scans num_reds $ patNames pat + (scan_res, red_res, map_res) = splitAt3 num_scans num_reds $ patNames pat to_red <- replicateM num_reds $ newVName "to_red" let scanomap = scanomapSOAC scans map_lam letBindNames (scan_res <> to_red <> map_res) $ - Op $ - Screma w arrs scanomap + Op (Screma w arrs scanomap) reduce <- reduceSOAC reds letBindNames red_res $ Op $ Screma w to_red reduce diff --git a/src/Futhark/Transform/FirstOrderTransform.hs b/src/Futhark/Transform/FirstOrderTransform.hs index 0a1917c32c..9ee22a6aee 100644 --- a/src/Futhark/Transform/FirstOrderTransform.hs +++ b/src/Futhark/Transform/FirstOrderTransform.hs @@ -125,7 +125,7 @@ transformSOAC _ JVP {} = error "transformSOAC: unhandled JVP" transformSOAC _ VJP {} = error "transformSOAC: unhandled VJP" -transformSOAC pat (Screma w arrs form@(ScremaForm scans reds map_lam)) = do +transformSOAC pat (Screma w arrs form@(ScremaForm map_lam scans reds)) = do -- See Note [Translation of Screma]. -- -- Start by combining all the reduction and scan parts into a single From daa7181d2f1494a3924a2486787de89ee4dad0dd Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 4 Dec 2024 20:51:19 +0100 Subject: [PATCH 2/2] Flip VJP/JVP operands to put lambda last. --- src/Futhark/IR/Parse.hs | 8 ++-- src/Futhark/IR/SOACS/SOAC.hs | 74 +++++++++++++++----------------- src/Futhark/IR/SOACS/Simplify.hs | 8 ++-- src/Futhark/Internalise/Exps.hs | 4 +- src/Futhark/Optimise/Fusion.hs | 8 ++-- src/Futhark/Pass/AD.hs | 8 ++-- 6 files changed, 52 insertions(+), 58 deletions(-) diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index 71a2ef1755..298e149617 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -798,19 +798,19 @@ pSOAC pr = pVJP = parens $ SOAC.VJP - <$> pLambda pr + <$> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma - <*> braces (pSubExp `sepBy` pComma) + <*> pLambda pr pJVP = parens $ SOAC.JVP - <$> pLambda pr + <$> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma - <*> braces (pSubExp `sepBy` pComma) + <*> pLambda pr pSizeClass :: Parser GPU.SizeClass pSizeClass = diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs index e048b287a3..265582e361 100644 --- a/src/Futhark/IR/SOACS/SOAC.hs +++ b/src/Futhark/IR/SOACS/SOAC.hs @@ -125,9 +125,9 @@ data SOAC rep -- The final lambda produces indexes and values for the 'HistOp's. Hist SubExp [VName] [HistOp rep] (Lambda rep) | -- FIXME: this should not be here - JVP (Lambda rep) [SubExp] [SubExp] + JVP [SubExp] [SubExp] (Lambda rep) | -- FIXME: this should not be here - VJP (Lambda rep) [SubExp] [SubExp] + VJP [SubExp] [SubExp] (Lambda rep) | -- | A combination of scan, reduction, and map. The first -- t'SubExp' is the size of the input arrays. Screma SubExp [VName] (ScremaForm rep) @@ -399,16 +399,16 @@ mapSOACM :: SOACMapper frep trep m -> SOAC frep -> m (SOAC trep) -mapSOACM tv (JVP lam args vec) = +mapSOACM tv (JVP args vec lam) = JVP - <$> mapOnSOACLambda tv lam - <*> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec -mapSOACM tv (VJP lam args vec) = + <*> mapOnSOACLambda tv lam +mapSOACM tv (VJP args vec lam) = VJP - <$> mapOnSOACLambda tv lam - <*> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec + <*> mapOnSOACLambda tv lam mapSOACM tv (Stream size arrs accs lam) = Stream <$> mapOnSOACSubExp tv size @@ -514,12 +514,10 @@ instance (ASTRep rep) => Rename (SOAC rep) where -- | The type of a SOAC. soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type] -soacType (JVP lam _ _) = - lambdaReturnType lam - ++ lambdaReturnType lam -soacType (VJP lam _ _) = - lambdaReturnType lam - ++ map paramType (lambdaParams lam) +soacType (JVP _ _ lam) = + lambdaReturnType lam ++ lambdaReturnType lam +soacType (VJP _ _ lam) = + lambdaReturnType lam ++ map paramType (lambdaParams lam) soacType (Stream outersize _ accs lam) = map (substNamesInType substs) rtp where @@ -572,10 +570,10 @@ mapHistOp f (HistOp w rf dests nes lam) = HistOp w rf dests nes $ f lam instance CanBeAliased SOAC where - addOpAliases aliases (JVP lam args vec) = - JVP (Alias.analyseLambda aliases lam) args vec - addOpAliases aliases (VJP lam args vec) = - VJP (Alias.analyseLambda aliases lam) args vec + addOpAliases aliases (JVP args vec lam) = + JVP args vec (Alias.analyseLambda aliases lam) + addOpAliases aliases (VJP args vec lam) = + VJP args vec (Alias.analyseLambda aliases lam) addOpAliases aliases (Stream size arr accs lam) = Stream size arr accs $ Alias.analyseLambda aliases lam addOpAliases aliases (Scatter len arrs dests lam) = @@ -631,12 +629,12 @@ instance IsOp SOAC where where flattenBlocks (_, arr, ivs) = oneName arr <> mconcat (map (mconcat . fst) ivs) <> mconcat (map snd ivs) - opDependencies (JVP lam args vec) = + opDependencies (JVP args vec lam) = mconcat $ replicate 2 $ lambdaDependencies mempty lam $ zipWith (<>) (map depsOf' args) (map depsOf' vec) - opDependencies (VJP lam args vec) = + opDependencies (VJP args vec lam) = lambdaDependencies mempty lam @@ -713,7 +711,7 @@ instance (RepTypes rep) => ST.IndexOp (SOAC rep) where -- | Type-check a SOAC. typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep () -typeCheckSOAC (VJP lam args vec) = do +typeCheckSOAC (VJP args vec lam) = do args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec @@ -723,7 +721,7 @@ typeCheckSOAC (VJP lam args vec) = do PP.indent 2 (pretty (lambdaReturnType lam)) "does not match type of seed vector" PP.indent 2 (pretty vec_ts) -typeCheckSOAC (JVP lam args vec) = do +typeCheckSOAC (JVP args vec lam) = do args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec @@ -893,10 +891,10 @@ typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do <> " wrong for given scan and reduction functions." instance RephraseOp SOAC where - rephraseInOp r (VJP lam args vec) = - VJP <$> rephraseLambda r lam <*> pure args <*> pure vec - rephraseInOp r (JVP lam args vec) = - JVP <$> rephraseLambda r lam <*> pure args <*> pure vec + rephraseInOp r (VJP args vec lam) = + VJP args vec <$> rephraseLambda r lam + rephraseInOp r (JVP args vec lam) = + JVP args vec <$> rephraseLambda r lam rephraseInOp r (Stream w arrs acc lam) = Stream w arrs acc <$> rephraseLambda r lam rephraseInOp r (Scatter w arrs dests lam) = @@ -918,9 +916,9 @@ instance RephraseOp SOAC where onRed (Reduce comm op nes) = Reduce comm <$> rephraseLambda r op <*> pure nes instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where - opMetrics (VJP lam _ _) = + opMetrics (VJP _ _ lam) = inside "VJP" $ lambdaMetrics lam - opMetrics (JVP lam _ _) = + opMetrics (JVP _ _ lam) = inside "JVP" $ lambdaMetrics lam opMetrics (Stream _ _ _ lam) = inside "Stream" $ lambdaMetrics lam @@ -935,25 +933,21 @@ instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where mapM_ (lambdaMetrics . redLambda) reds instance (PrettyRep rep) => PP.Pretty (SOAC rep) where - pretty (VJP lam args vec) = + pretty (VJP args vec lam) = "vjp" <> parens ( PP.align $ - pretty lam - <> comma - PP.braces (commasep $ map pretty args) - <> comma - PP.braces (commasep $ map pretty vec) + PP.braces (commasep $ map pretty args) + <> comma PP.braces (commasep $ map pretty vec) + <> comma pretty lam ) - pretty (JVP lam args vec) = + pretty (JVP args vec lam) = "jvp" <> parens ( PP.align $ - pretty lam - <> comma - PP.braces (commasep $ map pretty args) - <> comma - PP.braces (commasep $ map pretty vec) + PP.braces (commasep $ map pretty args) + <> comma PP.braces (commasep $ map pretty vec) + <> comma pretty lam ) pretty (Stream size arrs acc lam) = ppStream size arrs acc lam diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index e15a6332b4..38254646ba 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -82,16 +82,16 @@ simplifyConsts = simplifySOAC :: (Simplify.SimplifiableRep rep) => Simplify.SimplifyOp rep (SOAC (Wise rep)) -simplifySOAC (VJP lam arr vec) = do +simplifySOAC (VJP arr vec lam) = do (lam', hoisted) <- Engine.simplifyLambda mempty lam arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (VJP lam' arr' vec', hoisted) -simplifySOAC (JVP lam arr vec) = do + pure (VJP arr' vec' lam', hoisted) +simplifySOAC (JVP arr vec lam) = do (lam', hoisted) <- Engine.simplifyLambda mempty lam arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (JVP lam' arr' vec', hoisted) + pure (JVP arr' vec' lam', hoisted) simplifySOAC (Stream outerdim arr nes lam) = do outerdim' <- Engine.simplify outerdim nes' <- mapM Engine.simplify nes diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 8d02701647..4c310aa8cf 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1710,8 +1710,8 @@ isIntrinsicFunction qname args loc = do lam <- internaliseLambdaCoerce f =<< mapM subExpType x' fmap (map I.Var) . letTupExp desc . Op $ case fname of - "jvp2" -> JVP lam x' v' - _ -> VJP lam x' v' + "jvp2" -> JVP x' v' lam + _ -> VJP x' v' lam handleAD _ _ = Nothing handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v diff --git a/src/Futhark/Optimise/Fusion.hs b/src/Futhark/Optimise/Fusion.hs index 404dffc793..703ce2a047 100644 --- a/src/Futhark/Optimise/Fusion.hs +++ b/src/Futhark/Optimise/Fusion.hs @@ -541,12 +541,12 @@ runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases defbody' <- doFusionWithDelayed defbody to_fuse pure (incoming, node, MatchNode (Let pat aux (Match cond cases' defbody' dec)) [], outgoing) - StmNode (Let pat aux (Op (Futhark.VJP lam args vec))) -> doFuseScans $ do + StmNode (Let pat aux (Op (Futhark.VJP args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP lam' args vec))), outgoing) - StmNode (Let pat aux (Op (Futhark.JVP lam args vec))) -> doFuseScans $ do + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP args vec lam'))), outgoing) + StmNode (Let pat aux (Op (Futhark.JVP args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP lam' args vec))), outgoing) + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP args vec lam'))), outgoing) StmNode (Let pat aux (WithAcc inputs lam)) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam pure (incoming, node, StmNode (Let pat aux (WithAcc inputs lam')), outgoing) diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index 3d7691b034..dca9ad088c 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -36,20 +36,20 @@ bindLambda pat aux (Lambda params _ body) args = do certifying cs $ letBindNames [v] $ BasicOp $ SubExp se onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS) -onStm mode scope (Let pat aux (Op (VJP lam args vec))) = do +onStm mode scope (Let pat aux (Op (VJP args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ VJP lam' args vec -onStm mode scope (Let pat aux (Op (JVP lam args vec))) = do + else pure $ oneStm $ Let pat aux $ Op $ VJP args vec lam' +onStm mode scope (Let pat aux (Op (JVP args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do lam'' <- fwdJVP scope lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ JVP lam' args vec + else pure $ oneStm $ Let pat aux $ Op $ JVP args vec lam' onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e where mapper =