From dbf468980422b3597948de016791a7fa0107fac4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 28 Jan 2025 13:07:24 +0100 Subject: [PATCH] Flattening of concat (#2214) --- futhark.cabal | 4 + rewritefut/segupdate.fut | 36 + src/Futhark/CLI/Dev.hs | 2 + src/Futhark/IR/Pretty.hs | 1 + src/Futhark/IR/TypeCheck.hs | 12 +- src/Futhark/Pass/ExtractKernels/ToGPU.hs | 4 + src/Futhark/Pass/Flatten.hs | 1456 +++++++++++++++++ src/Futhark/Pass/Flatten/Builtins.hs | 483 ++++++ src/Futhark/Pass/Flatten/Distribute.hs | 261 +++ src/Futhark/Passes.hs | 4 +- src/Futhark/Util.hs | 7 + tests/flattening/CosminArrayExample.fut | 17 - tests/flattening/HighlyNestedMap.fut | 41 - tests/flattening/IntmRes1.fut | 23 - tests/flattening/IntmRes2.fut | 30 - tests/flattening/IntmRes3.fut | 36 - tests/flattening/LoopInv1.fut | 24 - tests/flattening/LoopInv2.fut | 26 - tests/flattening/LoopInv3.fut | 34 - tests/flattening/LoopInvReshape.fut | 16 - tests/flattening/Map-IotaMapReduce.fut | 14 - tests/flattening/Map-Map-IotaMapReduce.fut | 18 - tests/flattening/MapIotaReduce.fut | 12 - tests/flattening/MatrixAddition.fut | 16 - tests/flattening/SimpleReduce.fut | 11 - tests/flattening/VectorAddition.fut | 10 - tests/flattening/binop.fut | 5 + tests/flattening/concat-check-index.fut | 27 + tests/flattening/concat-iota.fut | 8 + tests/flattening/concat-rep.fut | 8 + tests/flattening/dup2d.fut | 7 + tests/flattening/dup3d.fut | 9 + tests/flattening/flattening-pipeline | 2 - tests/flattening/flattening-test | 11 - .../function-lifting/func_const.fut | 22 + .../flattening/function-lifting/func_free.fut | 27 + .../function-lifting/func_fully_irreg.fut | 23 + .../function-lifting/func_irreg_input.fut | 17 + .../function-lifting/func_irreg_result.fut | 17 + .../function-lifting/func_irreg_update.fut | 23 + .../flattening/function-lifting/func_mix.fut | 25 + .../function-lifting/func_mix_nested.fut | 31 + .../function-lifting/func_simple.fut | 16 + tests/flattening/iota-index.fut | 10 + tests/flattening/iota-opaque-index.fut | 9 + tests/flattening/iota-opaque-slice-red.fut | 11 + tests/flattening/iota-red.fut | 7 + tests/flattening/map-nested-deeper.fut | 9 + tests/flattening/map-nested-free2d.fut | 9 + tests/flattening/map-nested.fut | 5 + tests/flattening/map-slice-nested.fut | 5 + tests/flattening/mapout.fut | 11 + tests/flattening/match-case/if.fut | 17 + .../flattening/match-case/if_fully_irreg.fut | 24 + .../flattening/match-case/if_irreg_input.fut | 17 + .../flattening/match-case/if_irreg_result.fut | 20 + .../match-case/match_fully_irreg.fut | 25 + tests/flattening/range-irreg-stride.fut | 7 + tests/flattening/range-opaque-red.fut | 7 + tests/flattening/rearrange0.fut | 5 + tests/flattening/rearrange1.fut | 5 + tests/flattening/redomap1.fut | 17 - tests/flattening/redomap2.fut | 13 - tests/flattening/replicate0.fut | 6 + tests/flattening/replicate1.fut | 7 + tests/flattening/slice-red.fut | 5 + tests/flattening/slice2d-red.fut | 5 + tests/flattening/update_dimfix.fut | 38 + tests/flattening/update_fully_irregular.fut | 7 + tests/flattening/update_invariant_is.fut | 7 + tests/flattening/update_invariant_vs.fut | 8 + tests/flattening/update_invariant_xs.fut | 7 + tests/flattening/update_mixdim.fut | 12 + tests/flattening/update_multdim.fut | 11 + tests/flattening/update_variant_is.fut | 7 + tests/flattening/update_variant_vs.fut | 7 + tests/flattening/update_variant_xs.fut | 7 + 77 files changed, 2864 insertions(+), 379 deletions(-) create mode 100644 rewritefut/segupdate.fut create mode 100644 src/Futhark/Pass/Flatten.hs create mode 100644 src/Futhark/Pass/Flatten/Builtins.hs create mode 100644 src/Futhark/Pass/Flatten/Distribute.hs delete mode 100644 tests/flattening/CosminArrayExample.fut delete mode 100644 tests/flattening/HighlyNestedMap.fut delete mode 100644 tests/flattening/IntmRes1.fut delete mode 100644 tests/flattening/IntmRes2.fut delete mode 100644 tests/flattening/IntmRes3.fut delete mode 100644 tests/flattening/LoopInv1.fut delete mode 100644 tests/flattening/LoopInv2.fut delete mode 100644 tests/flattening/LoopInv3.fut delete mode 100644 tests/flattening/LoopInvReshape.fut delete mode 100644 tests/flattening/Map-IotaMapReduce.fut delete mode 100644 tests/flattening/Map-Map-IotaMapReduce.fut delete mode 100644 tests/flattening/MapIotaReduce.fut delete mode 100644 tests/flattening/MatrixAddition.fut delete mode 100644 tests/flattening/SimpleReduce.fut delete mode 100644 tests/flattening/VectorAddition.fut create mode 100644 tests/flattening/binop.fut create mode 100644 tests/flattening/concat-check-index.fut create mode 100644 tests/flattening/concat-iota.fut create mode 100644 tests/flattening/concat-rep.fut create mode 100644 tests/flattening/dup2d.fut create mode 100644 tests/flattening/dup3d.fut delete mode 100755 tests/flattening/flattening-pipeline delete mode 100755 tests/flattening/flattening-test create mode 100644 tests/flattening/function-lifting/func_const.fut create mode 100644 tests/flattening/function-lifting/func_free.fut create mode 100644 tests/flattening/function-lifting/func_fully_irreg.fut create mode 100644 tests/flattening/function-lifting/func_irreg_input.fut create mode 100644 tests/flattening/function-lifting/func_irreg_result.fut create mode 100644 tests/flattening/function-lifting/func_irreg_update.fut create mode 100644 tests/flattening/function-lifting/func_mix.fut create mode 100644 tests/flattening/function-lifting/func_mix_nested.fut create mode 100644 tests/flattening/function-lifting/func_simple.fut create mode 100644 tests/flattening/iota-index.fut create mode 100644 tests/flattening/iota-opaque-index.fut create mode 100644 tests/flattening/iota-opaque-slice-red.fut create mode 100644 tests/flattening/iota-red.fut create mode 100644 tests/flattening/map-nested-deeper.fut create mode 100644 tests/flattening/map-nested-free2d.fut create mode 100644 tests/flattening/map-nested.fut create mode 100644 tests/flattening/map-slice-nested.fut create mode 100644 tests/flattening/mapout.fut create mode 100644 tests/flattening/match-case/if.fut create mode 100644 tests/flattening/match-case/if_fully_irreg.fut create mode 100644 tests/flattening/match-case/if_irreg_input.fut create mode 100644 tests/flattening/match-case/if_irreg_result.fut create mode 100644 tests/flattening/match-case/match_fully_irreg.fut create mode 100644 tests/flattening/range-irreg-stride.fut create mode 100644 tests/flattening/range-opaque-red.fut create mode 100644 tests/flattening/rearrange0.fut create mode 100644 tests/flattening/rearrange1.fut delete mode 100644 tests/flattening/redomap1.fut delete mode 100644 tests/flattening/redomap2.fut create mode 100644 tests/flattening/replicate0.fut create mode 100644 tests/flattening/replicate1.fut create mode 100644 tests/flattening/slice-red.fut create mode 100644 tests/flattening/slice2d-red.fut create mode 100644 tests/flattening/update_dimfix.fut create mode 100644 tests/flattening/update_fully_irregular.fut create mode 100644 tests/flattening/update_invariant_is.fut create mode 100644 tests/flattening/update_invariant_vs.fut create mode 100644 tests/flattening/update_invariant_xs.fut create mode 100644 tests/flattening/update_mixdim.fut create mode 100644 tests/flattening/update_multdim.fut create mode 100644 tests/flattening/update_variant_is.fut create mode 100644 tests/flattening/update_variant_vs.fut create mode 100644 tests/flattening/update_variant_xs.fut diff --git a/futhark.cabal b/futhark.cabal index c717050d96..9567a221e3 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -369,6 +369,9 @@ library Futhark.Pass.ExtractKernels.ToGPU Futhark.Pass.ExtractMulticore Futhark.Pass.FirstOrderTransform + Futhark.Pass.Flatten + Futhark.Pass.Flatten.Builtins + Futhark.Pass.Flatten.Distribute Futhark.Pass.LiftAllocations Futhark.Pass.LowerAllocations Futhark.Pass.Simplify @@ -475,6 +478,7 @@ library , lsp-types >= 2.0.1.0 , mainland-pretty >=0.7.1 , cmark-gfm >=0.2.1 + , OneTuple , megaparsec >=9.0.0 , mtl >=2.2.1 , neat-interpolation >=0.3 diff --git a/rewritefut/segupdate.fut b/rewritefut/segupdate.fut new file mode 100644 index 0000000000..980d3702db --- /dev/null +++ b/rewritefut/segupdate.fut @@ -0,0 +1,36 @@ +-- Flat-Parallel Segmented Update +-- == +-- compiled input { [1i64,2i64,3i64,1i64,2i64,1i64,2i64,3i64,4i64] [3i64,2i64,4i64] [0i64,0i64,0i64,0i64,0i64] [2i64,1i64,2i64] [0i64, 1i64, 0i64] [1i64, 1i64, 2i64] } output { [0i64,0i64,3i64,1i64,0i64,0i64,2i64,0i64,4i64] } + +let sgmSumI64 [n] (flg : [n]i64) (arr : [n]i64) : [n]i64 = + let flgs_vals = + scan ( \ (f1, x1) (f2,x2) -> + let f = f1 | f2 in + if f2 != 0 then (f, x2) + else (f, x1 + x2) ) + (0, 0i64) (zip flg arr) + let (_, vals) = unzip flgs_vals + in vals + +let mkFlagArray [m] (aoa_shp: [m]i64) (zero: i64) + (aoa_val: [m]i64) : []i64 = + let shp_rot = map(\i -> if i==0i64 then 0i64 else aoa_shp[i-1]) (iota m) + let shp_scn = scan (+) 0i64 shp_rot + let aoa_len = shp_scn[m-1]+aoa_shp[m-1] + let shp_ind = map2 (\shp ind -> if shp==0 then -1i64 else ind) aoa_shp shp_scn + in scatter (replicate aoa_len zero) shp_ind aoa_val + +let segUpdate [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + let fvss = (mkFlagArray shp_vss 0 (1...t :> [t]i64)) :> [m]i64 + let II1 = sgmSumI64 fvss fvss |> map (\x -> x - 1) + let shp_xss_rot = map(\i -> if i==0i64 then 0i64 else shp_xss[i-1]) (iota t) + let bxss = scan (+) 0 shp_xss_rot + let II2 = sgmSumI64 fvss (replicate m 1) |> map (\x -> x - 1) + let iss = map (\i -> bxss[II1[i]] + bs[II1[i]] + (II2[i] * ss[II1[i]])) (iota m) + in scatter xss_val iss vss_val + + +let main [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + segUpdate xss_val shp_xss vss_val shp_vss bs ss \ No newline at end of file diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs index 9f3a4cce9a..d5ad228786 100644 --- a/src/Futhark/CLI/Dev.hs +++ b/src/Futhark/CLI/Dev.hs @@ -54,6 +54,7 @@ import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten (flattenSOACs) import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify @@ -654,6 +655,7 @@ commandLineOptions = sinkOption [], kernelsPassOption reduceDeviceSyncs [], typedPassOption soacsProg GPU extractKernels [], + typedPassOption soacsProg GPU flattenSOACs [], typedPassOption soacsProg MC extractMulticore [], allocateOption "a", kernelsMemPassOption doubleBufferGPU [], diff --git a/src/Futhark/IR/Pretty.hs b/src/Futhark/IR/Pretty.hs index 609e18eedb..b7f49d0581 100644 --- a/src/Futhark/IR/Pretty.hs +++ b/src/Futhark/IR/Pretty.hs @@ -8,6 +8,7 @@ module Futhark.IR.Pretty ( prettyTuple, prettyTupleLines, prettyString, + prettyRet, PrettyRep (..), ) where diff --git a/src/Futhark/IR/TypeCheck.hs b/src/Futhark/IR/TypeCheck.hs index 8a690ff4ee..fa3c769534 100644 --- a/src/Futhark/IR/TypeCheck.hs +++ b/src/Futhark/IR/TypeCheck.hs @@ -58,7 +58,7 @@ import Futhark.Analysis.PrimExp import Futhark.Construct (instantiateShapes) import Futhark.IR.Aliases hiding (lookupAliases) import Futhark.Util -import Futhark.Util.Pretty (align, docText, indent, ppTuple', pretty, (<+>), ()) +import Futhark.Util.Pretty hiding (width) -- | Information about an error during type checking. The 'Show' -- instance for this type produces a human-readable description. @@ -742,7 +742,7 @@ checkSubExp (Var ident) = context ("In subexp " <> prettyText ident) $ do lookupType ident checkCerts :: (Checkable rep) => Certs -> TypeM rep () -checkCerts (Certs cs) = mapM_ (requireI [Prim Unit]) cs +checkCerts = mapM_ lookupType . unCerts checkSubExpRes :: (Checkable rep) => SubExpRes -> TypeM rep Type checkSubExpRes (SubExpRes cs se) = do @@ -1023,9 +1023,9 @@ checkExp (Apply fname args rettype_annot _) = do when (rettype_derived /= rettype_annot) $ bad . TypeError . docText $ "Expected apply result type:" - indent 2 (pretty $ map fst rettype_derived) + indent 2 (braces $ commasep $ map prettyRet rettype_derived) "But annotation is:" - indent 2 (pretty $ map fst rettype_annot) + indent 2 (braces $ commasep $ map prettyRet rettype_annot) consumeArgs paramtypes argflows checkExp (Loop merge form loopbody) = do let (mergepat, mergeexps) = unzip merge @@ -1252,8 +1252,8 @@ checkStm :: Stm (Aliases rep) -> TypeM rep a -> TypeM rep a -checkStm stm@(Let pat (StmAux (Certs cs) _ (_, dec)) e) m = do - context "When checking certificates" $ mapM_ (requireI [Prim Unit]) cs +checkStm stm@(Let pat (StmAux cs _ (_, dec)) e) m = do + context "When checking certificates" $ checkCerts cs context "When checking expression annotation" $ checkExpDec dec context ("When matching\n" <> message " " pat <> "\nwith\n" <> message " " e) $ matchPat pat e diff --git a/src/Futhark/Pass/ExtractKernels/ToGPU.hs b/src/Futhark/Pass/ExtractKernels/ToGPU.hs index 121b8fabc6..6104107ee4 100644 --- a/src/Futhark/Pass/ExtractKernels/ToGPU.hs +++ b/src/Futhark/Pass/ExtractKernels/ToGPU.hs @@ -5,6 +5,7 @@ module Futhark.Pass.ExtractKernels.ToGPU segThread, soacsLambdaToGPU, soacsStmToGPU, + soacsExpToGPU, scopeForGPU, scopeForSOACs, injectSOACS, @@ -74,6 +75,9 @@ injectSOACS f = soacsStmToGPU :: Stm SOACS -> Stm GPU soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp) +soacsExpToGPU :: Exp SOACS -> Exp GPU +soacsExpToGPU = runIdentity . rephraseExp (injectSOACS OtherOp) + soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU soacsLambdaToGPU = runIdentity . rephraseLambda (injectSOACS OtherOp) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs new file mode 100644 index 0000000000..5101ceedbd --- /dev/null +++ b/src/Futhark/Pass/Flatten.hs @@ -0,0 +1,1456 @@ +{-# LANGUAGE TypeFamilies #-} + +-- The idea is to perform distribution on one level at a time, and +-- produce "irregular Maps" that can accept and produce irregular +-- arrays. These irregular maps will then be transformed into flat +-- parallelism based on their contents. This is a sensitive detail, +-- but if irregular maps contain only a single Stm, then it is fairly +-- straightforward, as we simply implement flattening rules for every +-- single kind of expression. Of course that is also somewhat +-- inefficient, so we want to support multiple Stms for things like +-- scalar code. +module Futhark.Pass.Flatten (flattenSOACs) where + +import Control.Monad +import Control.Monad.Reader +import Data.Bifunctor (bimap, first, second) +import Data.Foldable +import Data.List qualified as L +import Data.List.NonEmpty qualified as NE +import Data.Map qualified as M +import Data.Maybe (fromMaybe, isNothing, mapMaybe) +import Data.Tuple.Solo +import Debug.Trace +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass +import Futhark.Pass.ExtractKernels.ToGPU (soacsStmToGPU) +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Tools +import Futhark.Transform.Rename +import Futhark.Transform.Substitute +import Futhark.Util.IntegralExp +import Prelude hiding (div, rem) + +-- Note [Representation of Flat Arrays] +-- +-- This flattening implementation uses largely the nomenclature and +-- structure described by Cosmin Oancea. In particular, consider an +-- irregular array 'A' where +-- +-- - A has 'n' segments (outermost dimension). +-- +-- - A has element type 't'. +-- +-- - A has a total of 'm' elements (where 'm' is divisible by 'n', +-- and may indeed be 'm'). +-- +-- Then A is represented by the following arrays: +-- +-- - A_D : [m]t; the "data array". +-- +-- - A_S : [n]i64; the "shape array" giving the size of each segment. +-- +-- - A_F : [m]bool; the "flag array", indicating when an element begins a +-- new segment. +-- +-- - A_O : [n]i64; the offset array, indicating for each segment +-- where it starts in the data (and flag) array. +-- +-- - A_II1 : [m]t; the "segment indices"; a mapping from element +-- index to index of the segment it belongs to. +-- +-- - A_II2 : [m]t; the "inner indices"; a mapping from element index +-- to index within its corresponding segment. +-- +-- The arrays that are not the data array are collectively called the +-- "structure arrays". All of the structure arrays can be computed +-- from each other, but conceptually they all coexist. +-- +-- Note that we only consider the *outer* dimension to be the +-- "segments". Also, 't' may actually be an array itself (although in +-- this case, the shape of 't' must be invariant to all parallel +-- dimensions). The inner structure is preserved through code, not +-- data. (Or in practice, ad-hoc auxiliary arrays produced by code.) +-- In Cosmin's notation, we maintain only the information for the +-- outermost dimension. +-- +-- As an example, consider an irregular array +-- +-- A = [ [], [ [1,2,3], [4], [], [5,6] ], [ [7], [], [8,9,10] ] ] +-- +-- then +-- +-- n = 3 +-- +-- m = 10 +-- +-- A_D = [1,2,3,4,5,6,7,8,9,10] +-- +-- A_S = [0, 4, 3] +-- +-- A_F = [T,F,F,F,F,F,T,F,F,F] +-- +-- A_O = [0, 0, 6] +-- +-- A_II1 = [1,1,1,1,1,1,2,2,2,2] +-- +-- A_II2 = [0,0,0,1,3,3,0,2,2,2] + +data IrregularRep = IrregularRep + { -- | Array of size of each segment, type @[]i64@. + irregularS :: VName, + irregularF :: VName, + irregularO :: VName, + irregularD :: VName + } + deriving (Show) + +data ResRep + = -- | This variable is represented + -- completely straightforwardly- if it is + -- an array, it is a regular array. + Regular VName + | -- | The representation of an + -- irregular array. + Irregular IrregularRep + deriving (Show) + +newtype DistEnv = DistEnv {distResMap :: M.Map ResTag ResRep} + +insertRep :: ResTag -> ResRep -> DistEnv -> DistEnv +insertRep rt rep env = env {distResMap = M.insert rt rep $ distResMap env} + +insertReps :: [(ResTag, ResRep)] -> DistEnv -> DistEnv +insertReps = flip $ foldl (flip $ uncurry insertRep) + +insertIrregular :: VName -> VName -> VName -> ResTag -> VName -> DistEnv -> DistEnv +insertIrregular ns flags offsets rt elems env = + let rep = Irregular $ IrregularRep ns flags offsets elems + in insertRep rt rep env + +insertIrregulars :: VName -> VName -> VName -> [(ResTag, VName)] -> DistEnv -> DistEnv +insertIrregulars ns flags offsets bnds env = + let (tags, elems) = unzip bnds + mkRep = Irregular . IrregularRep ns flags offsets + in insertReps (zip tags $ map mkRep elems) env + +insertRegulars :: [ResTag] -> [VName] -> DistEnv -> DistEnv +insertRegulars rts xs = + insertReps (zip rts $ map Regular xs) + +instance Monoid DistEnv where + mempty = DistEnv mempty + +instance Semigroup DistEnv where + DistEnv x <> DistEnv y = DistEnv (x <> y) + +resVar :: ResTag -> DistEnv -> ResRep +resVar rt env = fromMaybe bad $ M.lookup rt $ distResMap env + where + bad = error $ "resVar: unknown tag: " ++ show rt + +segsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName, VName), [VName]) +segsAndElems _ [] = (Nothing, []) +segsAndElems env (DistInputFree v _ : vs) = + second (v :) $ segsAndElems env vs +segsAndElems env (DistInput rt _ : vs) = + case resVar rt env of + Regular v' -> + second (v' :) $ segsAndElems env vs + Irregular (IrregularRep segments flags offsets elems) -> + bimap (mplus $ Just (segments, flags, offsets)) (elems :) $ segsAndElems env vs + +-- Mapping from original variable names to their distributed resreps +inputReps :: DistInputs -> DistEnv -> M.Map VName (Type, ResRep) +inputReps inputs env = M.fromList $ map (second getRep) inputs + where + getRep di = case di of + DistInput rt t -> (t, resVar rt env) + DistInputFree v' t -> (t, Regular v') + +type Segments = NE.NonEmpty SubExp + +segmentsShape :: Segments -> Shape +segmentsShape = Shape . toList + +segmentsRank :: Segments -> Int +segmentsRank = shapeRank . segmentsShape + +readInput :: Segments -> DistEnv -> [SubExp] -> DistInputs -> SubExp -> Builder GPU SubExp +readInput _ _ _ _ (Constant x) = pure $ Constant x +readInput _segments env is inputs (Var v) = + case lookup v inputs of + Nothing -> pure $ Var v + Just (DistInputFree arr _) -> + letSubExp (baseString v) =<< eIndex arr (map eSubExp is) + Just (DistInput rt _) -> do + case resVar rt env of + Regular arr -> + letSubExp (baseString v) =<< eIndex arr (map eSubExp is) + Irregular (IrregularRep _ _flags _offsets _elems) -> + undefined + +readInputs :: Segments -> DistEnv -> [SubExp] -> DistInputs -> Builder GPU () +readInputs _segments env is = mapM_ onInput + where + onInput (v, DistInputFree arr _) = + letBindNames [v] =<< eIndex arr (map eSubExp is) + onInput (v, DistInput rt t) = + case resVar rt env of + Regular arr -> + letBindNames [v] =<< eIndex arr (map eSubExp is) + Irregular (IrregularRep _ _ v_O v_D) -> do + offset <- letSubExp "offset" =<< eIndex v_O (map eSubExp is) + case arrayDims t of + [num_elems] -> do + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + letBindNames [v] $ BasicOp $ Index v_D slice + _ -> do + num_elems <- + letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + v_flat <- + letExp (baseString v <> "_flat") $ BasicOp $ Index v_D slice + letBindNames [v] . BasicOp $ + Reshape ReshapeArbitrary (arrayShape t) v_flat + +transformScalarStms :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Stms SOACS -> + [VName] -> + Builder GPU DistEnv +transformScalarStms segments env inps distres stms res = do + vs <- letTupExp "scalar_dist" <=< renameExp <=< segMap segments $ \is -> do + readInputs segments env (toList is) inps + addStms $ fmap soacsStmToGPU stms + pure $ subExpsRes $ map Var res + pure $ insertReps (zip (map distResTag distres) $ map Regular vs) env + +transformScalarStm :: + Segments -> + DistEnv -> + DistInputs -> + [DistResult] -> + Stm SOACS -> + Builder GPU DistEnv +transformScalarStm segments env inps res stm = + transformScalarStms segments env inps res (oneStm stm) (patNames (stmPat stm)) + +distCerts :: DistInputs -> StmAux a -> DistEnv -> Certs +distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux + where + f v = case lookup v inps of + Nothing -> v + Just (DistInputFree vs _) -> vs + Just (DistInput rt _) -> + case resVar rt env of + Regular vs -> vs + Irregular r -> irregularD r + +-- | Only sensible for variables of segment-invariant type. +dataArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName +dataArr segments env inps (Var v) + | Just v_inp <- lookup v inps = + case v_inp of + DistInputFree vs _ -> irregularD <$> mkIrregFromReg segments vs + DistInput rt _ -> case resVar rt env of + Irregular r -> pure $ irregularD r + Regular vs -> irregularD <$> mkIrregFromReg segments vs +dataArr segments _ _ se = do + rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se + dims <- arrayDims <$> lookupType rep + if length dims == 1 + then pure rep + else do + n <- toSubExp "n" $ product $ map pe64 dims + letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep + +mkIrregFromReg :: + Segments -> + VName -> + Builder GPU IrregularRep +mkIrregFromReg segments arr = do + arr_t <- lookupType arr + segment_size <- + letSubExp "reg_seg_size" <=< toExp . product . map pe64 $ + drop (segmentsRank segments) (arrayDims arr_t) + arr_S <- + letExp "reg_segments" . BasicOp $ + Replicate (segmentsShape segments) segment_size + num_elems <- + letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t + arr_D <- + letExp "reg_D" . BasicOp $ + Reshape ReshapeArbitrary (Shape [num_elems]) arr + arr_F <- letExp "reg_F" <=< segMap (MkSolo num_elems) $ \(MkSolo i) -> do + flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==. 0 + pure [subExpRes flag] + arr_O <- letExp "reg_O" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do + let flat_seg_i = + flattenIndex + (map pe64 (shapeDims (segmentsShape segments))) + (map pe64 is) + offset <- letSubExp "offset" <=< toExp $ flat_seg_i * pe64 segment_size + pure [subExpRes offset] + pure $ + IrregularRep + { irregularS = arr_S, + irregularF = arr_F, + irregularO = arr_O, + irregularD = arr_D + } + +-- Get the irregular representation of a var. +getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU IrregularRep +getIrregRep segments env inps v = + case lookup v inps of + Just v_inp -> case v_inp of + DistInputFree arr _ -> mkIrregFromReg segments arr + DistInput rt _ -> case resVar rt env of + Irregular r -> pure r + Regular arr -> mkIrregFromReg segments arr + Nothing -> do + v' <- + letExp (baseString v <> "_rep") . BasicOp $ + Replicate (segmentsShape segments) (Var v) + mkIrregFromReg segments v' + +-- Do 'map2 (++) A B' where 'A' and 'B' are irregular arrays and have the same +-- number of subarrays +concatIrreg :: + Segments -> + DistEnv -> + VName -> + [IrregularRep] -> + Builder GPU IrregularRep +concatIrreg _segments _env ns reparr = do + -- Concatenation does not change the number of segments - it simply + -- makes each of them larger. + + num_segments <- arraySize 0 <$> lookupType ns + + -- Constructs the full list size / shape that should hold the final results. + let zero = Constant $ IntValue $ intValue Int64 (0 :: Int) + ns_full <- letExp (baseString ns <> "_full") <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + old_segments <- + forM reparr $ \rep -> + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] + new_segment <- + letSubExp "new_segment" + =<< toExp (foldl (+) (pe64 zero) $ map pe64 old_segments) + pure $ subExpsRes [new_segment] + + (ns_full_F, ns_full_O, ns_II1) <- doRepIota ns_full + + repIota <- mapM (doRepIota . irregularS) reparr + segIota <- mapM (doSegIota . irregularS) reparr + + let (_, _, rep_II1) = unzip3 repIota + let (_, _, rep_II2) = unzip3 segIota + + n_arr <- mapM (fmap (arraySize 0) . lookupType) rep_II1 + + -- Calculate offsets for the scatter operations + let shapes = map irregularS reparr + scatter_offsets <- + letTupExp "irregular_scatter_offsets" <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + segment_sizes <- + forM shapes $ \shape -> + letSubExp "segment_size" =<< eIndex shape [eSubExp i] + let prefixes = L.init $ L.inits segment_sizes + sumprefix <- + mapM + ( letSubExp "segment_prefix" + <=< foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) + ) + prefixes + pure $ subExpsRes sumprefix + + scatter_offsets_T <- + letTupExp "irregular_scatter_offsets_T" <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + columns <- + forM scatter_offsets $ \offsets -> + letSubExp "segment_offset" =<< eIndex offsets [eSubExp i] + pure $ subExpsRes columns + + -- Scatter data into result array + elems <- + foldlM + ( \elems (reparr1, scatter_offset, n, ii1, ii2) -> do + letExp "irregular_scatter_elems" <=< genScatter elems n $ \gid -> do + -- Which segment we are in. + segment_i <- + letSubExp "segment_i" =<< eIndex ii1 [eSubExp gid] + + -- Get segment offset in final array + segment_o <- + letSubExp "segment_o" =<< eIndex ns_full_O [eSubExp segment_i] + + -- Get local segment offset + segment_local_o <- + letSubExp "segment_local_o" + =<< eIndex scatter_offset [eSubExp segment_i] + + -- Value to write + v' <- + letSubExp "v" =<< eIndex (irregularD reparr1) [eSubExp gid] + o' <- letSubExp "o" =<< eIndex ii2 [eSubExp gid] + + -- Index to write `v'` at + i <- + letExp "i" =<< toExp (pe64 o' + pe64 segment_local_o + pe64 segment_o) + + pure (i, v') + ) + ns_II1 + $ L.zip5 reparr scatter_offsets_T n_arr rep_II1 rep_II2 + + pure $ + IrregularRep + { irregularS = ns_full, + irregularF = ns_full_F, + irregularO = ns_full_O, + irregularD = elems + } + +-- Do 'map2 replicate ns A', where 'A' is an irregular array (and so +-- is the result, obviously). +replicateIrreg :: + Segments -> + DistEnv -> + VName -> + String -> + IrregularRep -> + Builder GPU IrregularRep +replicateIrreg _segments _env ns desc rep = do + -- Replication does not change the number of segments - it simply + -- makes each of them larger. + + num_segments <- arraySize 0 <$> lookupType ns + + -- ns multipled with existing segment sizes. + ns_full <- letExp (baseString ns <> "_full") <=< segMap (MkSolo num_segments) $ + \(MkSolo i) -> do + n <- + letSubExp "n" =<< eIndex ns [eSubExp i] + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] + full_segment <- + letSubExp "new_segment" =<< toExp (pe64 n * pe64 old_segment) + pure $ subExpsRes [full_segment] + + (ns_full_F, ns_full_O, ns_full_D) <- doRepIota ns_full + (_, _, flat_to_segs) <- doSegIota ns_full + + w <- arraySize 0 <$> lookupType ns_full_D + + elems <- letExp (desc <> "_rep_D") <=< segMap (MkSolo w) $ \(MkSolo i) -> do + -- Which segment we are in. + segment_i <- + letSubExp "segment_i" =<< eIndex ns_full_D [eSubExp i] + -- Size of original segment. + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp segment_i] + -- Index of value inside *new* segment. + j_new <- + letSubExp "j_new" =<< eIndex flat_to_segs [eSubExp i] + -- Index of value inside *old* segment. + j_old <- + letSubExp "j_old" =<< toExp (pe64 j_new `rem` pe64 old_segment) + -- Offset of values in original segment. + offset <- + letSubExp "offset" =<< eIndex (irregularO rep) [eSubExp segment_i] + v <- + letSubExp "v" + =<< eIndex (irregularD rep) [toExp $ pe64 offset + pe64 j_old] + pure $ subExpsRes [v] + + pure $ + IrregularRep + { irregularS = ns_full, + irregularF = ns_full_F, + irregularO = ns_full_O, + irregularD = elems + } + +-- | Flatten the arrays of an IrregularRep to be entirely one-dimensional. +flattenIrregularRep :: IrregularRep -> Builder GPU IrregularRep +flattenIrregularRep ir@(IrregularRep shape flags offsets elems) = do + elems_t <- lookupType elems + if arrayRank elems_t == 1 + then pure ir + else do + n <- arraySize 0 <$> lookupType shape + m' <- letSubExp "flat_m" <=< toExp $ product $ map pe64 $ arrayDims elems_t + elems' <- + letExp (baseString elems <> "_flat") $ + BasicOp $ + Reshape ReshapeArbitrary (Shape [m']) elems + + shape' <- letExp (baseString shape <> "_flat") <=< renameExp <=< segMap (MkSolo n) $ + \(MkSolo i) -> do + old_shape <- letSubExp "old_shape" =<< eIndex shape [toExp i] + segment_shape <- + letSubExp "segment_shape" <=< toExp $ + pe64 old_shape * product (map pe64 $ tail $ arrayDims elems_t) + pure [subExpRes segment_shape] + + offsets' <- letExp (baseString offsets <> "_flat") <=< renameExp <=< segMap (MkSolo n) $ + \(MkSolo i) -> do + old_offsets <- letSubExp "old_offsets" =<< eIndex offsets [toExp i] + segment_offsets <- + letSubExp "segment_offsets" <=< toExp $ + pe64 old_offsets * product (map pe64 $ tail $ arrayDims elems_t) + pure [subExpRes segment_offsets] + + flags' <- letExp (baseString flags <> "_flat") <=< renameExp <=< segMap (MkSolo m') $ + \(MkSolo i) -> do + let head_i = head $ unflattenIndex (map pe64 $ arrayDims elems_t) (pe64 i) + flag <- letSubExp "flag" =<< eIndex flags [toExp head_i] + pure [subExpRes flag] + pure $ IrregularRep shape' flags' offsets' elems' + +rearrangeFlat :: (IntegralExp num) => [Int] -> [num] -> num -> num +rearrangeFlat perm dims i = + -- TODO? Maybe we need to invert one of these permutations. + flattenIndex dims $ + rearrangeShape perm $ + unflattenIndex (rearrangeShape perm dims) i + +rearrangeIrreg :: + Segments -> + DistEnv -> + TypeBase Shape u -> + [Int] -> + IrregularRep -> + Builder GPU IrregularRep +rearrangeIrreg _segments _env v_t perm ir = do + (IrregularRep shape flags offsets elems) <- flattenIrregularRep ir + m <- arraySize 0 <$> lookupType elems + (_, _, ii1_vss) <- doRepIota shape + (_, _, ii2_vss) <- doSegIota shape + elems' <- letExp "elems_rearrange" <=< renameExp <=< segMap (MkSolo m) $ + \(MkSolo i) -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [eSubExp seg_i] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp i] + let v_dims = map pe64 $ arrayDims v_t + in_seg_is_tr = rearrangeFlat perm v_dims $ pe64 in_seg_i + v' <- + letSubExp "v" + =<< eIndex elems [toExp $ pe64 offset + in_seg_is_tr] + pure [subExpRes v'] + pure $ + IrregularRep + { irregularS = shape, + irregularF = flags, + irregularO = offsets, + irregularD = elems' + } + +transformDistBasicOp :: + Segments -> + DistEnv -> + ( DistInputs, + DistResult, + PatElem Type, + StmAux (), + BasicOp + ) -> + Builder GPU DistEnv +transformDistBasicOp segments env (inps, res, pe, aux, e) = + case e of + BinOp {} -> + scalarCase + CmpOp {} -> + scalarCase + ConvOp {} -> + scalarCase + UnOp {} -> + scalarCase + Assert {} -> + scalarCase + Opaque _op se + | Var v <- se, + Just (DistInput rt_in _) <- lookup v inps -> + -- TODO: actually insert opaques + pure $ insertRep (distResTag res) (resVar rt_in env) env + | otherwise -> + scalarCase + Reshape _ _ arr + | Just (DistInput rt_in _) <- lookup arr inps -> + pure $ insertRep (distResTag res) (resVar rt_in env) env + Index arr slice + | null $ sliceDims slice -> + scalarCase + | otherwise -> do + -- Maximally irregular case. + ns <- letExp "slice_sizes" <=< segMap segments $ \is -> do + slice_ns <- mapM (readInput segments env (toList is) inps) $ sliceDims slice + fmap varsRes . letTupExp "n" <=< toExp $ product $ map pe64 slice_ns + (_n, offsets, m) <- exScanAndSum ns + (_, _, repiota_D) <- doRepIota ns + flags <- genFlags m offsets + elems <- letExp "elems" <=< renameExp <=< segMap (NE.singleton m) $ \is -> do + segment <- letSubExp "segment" =<< eIndex repiota_D (toList $ fmap eSubExp is) + segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] + readInputs segments env [segment] inps + -- TODO: multidimensional segments + let slice' = + fixSlice (fmap pe64 slice) $ + unflattenIndex (map pe64 (sliceDims slice)) $ + subtract (pe64 segment_start) . pe64 $ + NE.head is + auxing aux $ + fmap (subExpsRes . pure) . letSubExp "v" + =<< eIndex arr (map toExp slice') + pure $ insertIrregular ns flags offsets (distResTag res) elems env + Iota n (Constant x) (Constant s) Int64 + | zeroIsh x, + oneIsh s -> do + ns <- dataArr segments env inps n + (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns + pure $ insertIrregular ns flags offsets (distResTag res) elems env + Iota n x s it -> do + ns <- dataArr segments env inps n + xs <- dataArr segments env inps x + ss <- dataArr segments env inps s + (res_F, res_O, res_D) <- certifying (distCerts inps aux env) $ doSegIota ns + (_, _, repiota_D) <- doRepIota ns + m <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "iota_D_fixed" <=< segMap (MkSolo m) $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex repiota_D [eSubExp i] + v' <- letSubExp "v" =<< eIndex res_D [eSubExp i] + x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] + s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] + fmap (subExpsRes . pure) . letSubExp "v" <=< toExp $ + primExpFromSubExp (IntType it) x' + ~+~ sExt it (untyped (pe64 v')) + ~*~ primExpFromSubExp (IntType it) s' + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' env + Concat 0 arr shp -> do + ns <- dataArr segments env inps shp + reparr <- mapM (getIrregRep segments env inps) (NE.toList arr) + rep' <- concatIrreg segments env ns reparr + pure $ insertRep (distResTag res) (Irregular rep') env + Replicate (Shape [n]) (Var v) -> do + ns <- dataArr segments env inps n + rep <- getIrregRep segments env inps v + rep' <- replicateIrreg segments env ns (baseString v) rep + pure $ insertRep (distResTag res) (Irregular rep') env + Replicate (Shape [n]) (Constant v) -> do + ns <- dataArr segments env inps n + (res_F, res_O, res_D) <- + certifying (distCerts inps aux env) $ doSegIota ns + w <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' env + Replicate (Shape []) (Var v) -> + case lookup v inps of + Just (DistInputFree v' _) -> do + v'' <- + letExp (baseString v' <> "_copy") . BasicOp $ + Replicate mempty (Var v') + pure $ insertRegulars [distResTag res] [v''] env + Just (DistInput rt _) -> + case resVar rt env of + Irregular r -> do + let name = baseString (irregularD r) <> "_copy" + elems_copy <- + letExp name . BasicOp $ + Replicate mempty (Var $ irregularD r) + let rep = Irregular $ r {irregularD = elems_copy} + pure $ insertRep (distResTag res) rep env + Regular v' -> do + v'' <- + letExp (baseString v' <> "_copy") . BasicOp $ + Replicate mempty (Var v') + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + v' <- + letExp (baseString v <> "_copy_free") . BasicOp $ + Replicate (segmentsShape segments) (Var v) + pure $ insertRegulars [distResTag res] [v'] env + Update _ as slice (Var v) + | Just as_t <- distInputType <$> lookup as inps -> do + ns <- letExp "slice_sizes" + <=< renameExp + <=< segMap (shapeDims (segmentsShape segments)) + $ \is -> do + readInputs segments env is $ + filter ((`elem` sliceDims slice) . Var . fst) inps + n <- letSubExp "n" <=< toExp $ product $ map pe64 $ sliceDims slice + pure [subExpRes n] + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + -- Inner indices (1 and 2) of `ns` + (_, _, ii1_vss) <- doRepIota ns + (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns + -- Number of updates to perform + m <- arraySize 0 <$> lookupType ii2_vss + elems' <- letExp "elems_scatter" <=< renameExp <=< genScatter elems m $ \gid -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] + readInputs segments env [seg_i] $ filter ((/= as) . fst) inps + v_t <- lookupType v + let in_seg_is = + unflattenIndex (map pe64 (arrayDims v_t)) (pe64 in_seg_i) + slice' = fmap pe64 slice + flat_i = + flattenIndex + (map pe64 $ arrayDims as_t) + (fixSlice slice' in_seg_is) + -- Value to write + v' <- letSubExp "v" =<< eIndex v (map toExp in_seg_is) + o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] + -- Index to write `v'` at + i <- letExp "i" =<< toExp (pe64 o' + flat_i) + pure (i, v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env + | otherwise -> + error "Flattening update: destination is not input." + Rearrange perm v -> do + case lookup v inps of + Just (DistInputFree v' _) -> do + v'' <- + letExp (baseString v' <> "_tr") . BasicOp $ + Rearrange perm v' + pure $ insertRegulars [distResTag res] [v''] env + Just (DistInput rt v_t) -> do + case resVar rt env of + Irregular rep -> do + rep' <- + certifying (distCerts inps aux env) $ + rearrangeIrreg segments env v_t perm rep + pure $ insertRep (distResTag res) (Irregular rep') env + Regular v' -> do + let r = segmentsRank segments + v'' <- + letExp (baseString v' <> "_tr") . BasicOp $ + Rearrange ([0 .. r - 1] ++ map (+ r) perm) v' + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + let r = segmentsRank segments + v' <- + letExp (baseString v <> "_tr") . BasicOp $ + Rearrange ([0 .. r - 1] ++ map (+ r) perm) v + pure $ insertRegulars [distResTag res] [v'] env + _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e + where + scalarCase = + transformScalarStm segments env inps [res] $ + Let (Pat [pe]) aux (BasicOp e) + +-- Replicates inner dimension for inputs. +onMapFreeVar :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + (VName, VName, VName) -> + VName -> + Maybe (Builder GPU (VName, MapArray IrregularRep)) +onMapFreeVar _segments env inps ws (_ws_F, _ws_O, ws_data) v = do + let segments_per_elem = ws_data + v_inp <- lookup v inps + pure $ do + ws_prod <- arraySize 0 <$> lookupType ws_data + fmap (v,) $ case v_inp of + DistInputFree v' t -> do + fmap (`MapArray` t) + . letExp (baseString v <> "_rep_free_free_inp") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) + DistInput rt t -> case resVar rt env of + Irregular rep -> do + offsets <- letExp (baseString v <> "_rep_free_irreg_O") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex ws_data [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex (irregularO rep) [eSubExp segment]) + let rep' = + IrregularRep + { irregularS = ws, + irregularF = irregularF rep, + irregularO = offsets, + irregularD = irregularD rep + } + pure $ MapOther rep' t + Regular vs -> + fmap (`MapArray` t) + . letExp (baseString v <> "_rep_free_reg_inp") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex vs [eSubExp segment]) + +onMapInputArr :: + Segments -> + DistEnv -> + DistInputs -> + VName -> + Param Type -> + VName -> + Builder GPU (MapArray IrregularRep) +onMapInputArr segments env inps ii2 p arr = do + ws_prod <- arraySize 0 <$> lookupType ii2 + case lookup arr inps of + Just v_inp -> + case v_inp of + DistInputFree vs t -> do + let inner_shape = arrayShape $ paramType p + v <- + letExp (baseString vs <> "_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [ws_prod] <> inner_shape) vs + pure $ MapArray v t + DistInput rt _ -> + case resVar rt env of + Irregular rep -> do + elems_t <- lookupType $ irregularD rep + -- If parameter type of the map corresponds to the + -- element type of the value array, we can map it + -- directly. + if stripArray (segmentsRank segments) elems_t == paramType p + then pure $ MapArray (irregularD rep) elems_t + else do + -- Otherwise we need to perform surgery on the metadata. + ~[p_segments, p_O] <- letTupExp + (baseString (paramName p) <> "_rep_inp_irreg") + <=< segMap (MkSolo ws_prod) + $ \(MkSolo i) -> do + segment_i <- + letSubExp "segment" =<< eIndex ii2 [eSubExp i] + segment <- + letSubExp "v" =<< eIndex (irregularS rep) [eSubExp segment_i] + offset <- + letSubExp "v" =<< eIndex (irregularO rep) [eSubExp segment_i] + pure $ subExpsRes [segment, offset] + let rep' = + IrregularRep + { irregularD = irregularD rep, + irregularF = irregularF rep, + irregularS = p_segments, + irregularO = p_O + } + pure $ MapOther rep' elems_t + Regular _vs -> + undefined + Nothing -> do + arr_row_t <- rowType <$> lookupType arr + arr_rep <- + letExp (baseString arr <> "_inp_rep") . BasicOp $ + Replicate (segmentsShape segments) (Var arr) + v <- + letExp (baseString arr <> "_inp_rep_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [ws_prod] <> arrayShape arr_row_t) arr_rep + pure $ MapArray v arr_row_t + +scopeOfDistInputs :: DistInputs -> Scope GPU +scopeOfDistInputs = scopeOfLParams . map f + where + f (v, inp) = Param mempty v (distInputType inp) + +transformInnerMap :: + Segments -> + DistEnv -> + DistInputs -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + Builder GPU (VName, VName, VName) +transformInnerMap segments env inps pat w arrs map_lam = do + ws <- dataArr segments env inps w + (ws_F, ws_O, ws_data) <- doRepIota ws + new_segment <- arraySize 0 <$> lookupType ws_data + arrs' <- + zipWithM + (onMapInputArr segments env inps ws_data) + (lambdaParams map_lam) + arrs + let free = freeIn map_lam + free_sizes <- + localScope (scopeOfDistInputs inps) $ + foldMap freeIn <$> mapM lookupType (namesToList free) + let free_and_sizes = namesToList $ free <> free_sizes + (free_replicated, replicated) <- + fmap unzip . sequence $ + mapMaybe + (onMapFreeVar segments env inps ws (ws_F, ws_O, ws_data)) + free_and_sizes + free_ps <- + zipWithM + newParam + (map ((<> "_free") . baseString) free_and_sizes) + (map mapArrayRowType replicated) + scope <- askScope + let substs = M.fromList $ zip free_replicated $ map paramName free_ps + map_lam' = + substituteNames + substs + ( map_lam + { lambdaParams = free_ps <> lambdaParams map_lam + } + ) + (distributed, arrmap) = + distributeMap scope pat new_segment (replicated <> arrs') map_lam' + m = + transformDistributed arrmap (NE.singleton new_segment) distributed + traceM $ unlines ["inner map distributed", prettyString distributed] + addStms =<< runReaderT (runBuilder_ m) scope + pure (ws_F, ws_O, ws) + +transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv +transformDistStm segments env (DistStm inps res stm) = do + case stm of + Let pat aux (BasicOp e) -> do + let ~[res'] = res + ~[pe] = patElems pat + transformDistBasicOp segments env (inps, res', pe, aux, e) + Let pat _ (Op (Screma w arrs form)) + | Just reds <- isReduceSOAC form, + Just arrs' <- mapM (`lookup` inps) arrs, + (Just (arr_segments, flags, offsets), elems) <- segsAndElems env arrs' -> do + elems' <- genSegRed arr_segments flags offsets elems $ singleReduce reds + pure $ insertReps (zip (map distResTag res) (map Regular elems')) env + | Just (reds, map_lam) <- isRedomapSOAC form -> do + map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> + PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) + (ws_F, ws_O, ws) <- + transformInnerMap segments env inps map_pat w arrs map_lam + let (redout_names, mapout_names) = + splitAt (redResults reds) (patNames map_pat) + elems' <- + genSegRed ws ws_F ws_O redout_names $ + singleReduce reds + let (red_tags, map_tags) = splitAt (redResults reds) $ map distResTag res + pure $ + insertRegulars red_tags elems' $ + insertIrregulars ws ws_F ws_O (zip map_tags mapout_names) env + | Just map_lam <- isMapSOAC form -> do + (ws_F, ws_O, ws) <- transformInnerMap segments env inps pat w arrs map_lam + pure $ insertIrregulars ws ws_F ws_O (zip (map distResTag res) $ patNames pat) env + Let _ _ (Match scrutinees cases defaultCase _) -> do + let [w] = NE.toList segments + + -- Lift the scrutinees. + -- If it's a variable, we know it's a scalar and the lifted version will therefore be a regular array. + lifted_scrutinees <- forM scrutinees $ \scrut -> do + (_, rep) <- liftSubExp segments inps env scrut + case rep of + Regular v' -> pure v' + Irregular {} -> + error $ + "transformDistStm: Non-scalar match scrutinee: " ++ prettyString scrut + -- Cases for tagging values that match the same branch. + -- The default case is the 0'th equvalence class. + let equiv_cases = + zipWith + ( \(Case pat _) n -> + Case pat $ eBody [toExp $ intConst Int64 n] + ) + cases + [1 ..] + let equiv_case_default = eBody [toExp $ intConst Int64 0] + -- Match the scrutinees againts the branch cases + equiv_classes <- letExp "equiv_classes" <=< segMap (MkSolo w) $ \(MkSolo i) -> do + scruts <- mapM (letSubExp "scruts" <=< flip eIndex [toExp i]) lifted_scrutinees + cls <- letSubExp "cls" =<< eMatch scruts equiv_cases equiv_case_default + pure [subExpRes cls] + let num_cases = fromIntegral $ length cases + 1 + n_cases <- letExp "n_cases" <=< toExp $ intConst Int64 num_cases + -- Parition the indices of the scrutinees by their equvalence class such + -- that (the indices) of the scrutinees belonging to class 0 come first, + -- then those belonging to class 1 and so on. + (partition_sizes, partition_offs, partition_inds) <- doPartition n_cases equiv_classes + inds_t <- lookupType partition_inds + -- Get the indices of each scrutinee by equivalence class + inds <- forM [0 .. num_cases - 1] $ \i -> do + num_data <- + letSubExp ("size" ++ show i) + =<< eIndex partition_sizes [toExp $ intConst Int64 i] + begin <- + letSubExp ("idx_begin" ++ show i) + =<< eIndex partition_offs [toExp $ intConst Int64 i] + letExp ("inds_branch" ++ show i) $ + BasicOp $ + Index partition_inds $ + fullSlice inds_t [DimSlice begin num_data (intConst Int64 1)] + + -- Take the elements at index `is` from an input `v`. + let splitInput is v = do + (t, rep) <- liftSubExp segments inps env (Var v) + (t,v,) <$> case rep of + Regular arr -> do + -- In the regular case we just take the elements + -- of the array given by `is` + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + arr' <- letExp "split_arr" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . pure <$> (letSubExp "arr" =<< eIndex arr [toExp idx]) + pure $ Regular arr' + Irregular (IrregularRep segs flags offsets elems) -> do + -- In the irregular case we take the elements + -- of the `segs` array given by `is` like in the regular case + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + segs' <- letExp "split_segs" <=< segMap (MkSolo n) $ \(MkSolo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . pure <$> (letSubExp "segs" =<< eIndex segs [toExp idx]) + -- From this we calculate the offsets and number of elements + (_, offsets', num_data) <- exScanAndSum segs' + (_, _, ii1) <- doRepIota segs' + (_, _, ii2) <- doSegIota segs' + -- We then take the elements we need from `elems` and `flags` + -- For each index `i`, we roughly: + -- Get the offset of the segment we want to copy by indexing + -- `offsets` through `is` further through `ii1` i.e. + -- `offset = offsets[is[ii1[i]]]` + -- We then add `ii2[i]` to `offset` + -- and use that to index into `elems` and `flags`. + ~[flags', elems'] <- letTupExp "split_F_data" <=< segMap (MkSolo num_data) $ \(MkSolo i) -> do + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp i]]] + idx <- letExp "idx" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp i]) + flags_split <- letSubExp "flags" =<< eIndex flags [toExp idx] + elems_split <- letSubExp "elems" =<< eIndex elems [toExp idx] + pure $ subExpsRes [flags_split, elems_split] + pure $ + Irregular $ + IrregularRep + { irregularS = segs', + irregularF = flags', + irregularO = offsets', + irregularD = elems' + } + -- Given the indices for which a branch is taken and its body, + -- distribute the statements of the body of that branch. + let distributeBranch is body = do + (ts, vs, reps) <- unzip3 <$> mapM (splitInput is) (namesToList $ freeIn body) + let inputs = do + (v, t, i) <- zip3 vs ts [0 ..] + pure (v, DistInput (ResTag i) t) + let env' = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + scope <- askScope + let (inputs', dstms) = distributeBody scope w inputs body + pure (inputs', env', dstms) + + -- Distribute and lift the branch bodies. + -- We put the default case at the start as it's the 0'th equivalence class + -- and is therefore the first segment after the partition. + let branch_bodies = defaultCase : map (\(Case _ body) -> body) cases + (branch_inputs, branch_envs, branch_dstms) <- unzip3 <$> zipWithM distributeBranch inds branch_bodies + let branch_results = map bodyResult branch_bodies + lifted_bodies <- forM [0 .. num_cases - 1] $ \i -> do + size <- letSubExp "size" =<< eIndex partition_sizes [toExp $ intConst Int64 i] + let inputs = branch_inputs !! fromIntegral i + let env' = branch_envs !! fromIntegral i + let dstms = branch_dstms !! fromIntegral i + let result = branch_results !! fromIntegral i + res' <- liftBody size inputs env' dstms result + subExpsRes <$> mapM (\(SubExpRes _ se) -> letSubExp ("result" ++ show i) =<< toExp se) res' + + let result_types = map ((\(DistType _ _ t) -> t) . distResType) res + branch_reps <- + mapM + ( fmap (resultToResReps result_types) + . mapM (letExp "branch_result" <=< toExp . resSubExp) + ) + lifted_bodies + + -- Write back the regular results of a branch to a (partially) blank space + let scatterRegular space (is, xs) = do + ~(Array _ (Shape [size]) _) <- lookupType xs + letExp "regular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex xs [eSubExp gtid] + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, x) + -- Write back the irregular elements of a branch to a (partially) blank space + -- The `offsets` variable is the offsets of the final result, + -- whereas `irregRep` is the irregular representation of the result of a single branch. + let scatterIrregular offsets space (is, irregRep) = do + let IrregularRep {irregularS = segs, irregularD = elems} = irregRep + (_, _, ii1) <- doRepIota segs + (_, _, ii2) <- doSegIota segs + ~(Array _ (Shape [size]) _) <- lookupType elems + letExp "irregular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex elems [eSubExp gtid] + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp gtid]]] + i <- letExp "i" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp gtid]) + pure (i, x) + -- Given a single result from each branch as well the *unlifted* + -- result type, merge the results of all branches into a single result. + let mergeResult iss branchesRep resType = + case resType of + -- Regular case + Prim pt -> do + let xs = map (\(Regular v) -> v) branchesRep + let resultType = Array pt (Shape [w]) NoUniqueness + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + result <- foldM scatterRegular resultSpace $ zip iss xs + pure $ Regular result + -- Irregular case + Array pt _ _ -> do + let branchesIrregRep = map (\(Irregular irregRep) -> irregRep) branchesRep + let segsType = Array (IntType Int64) (Shape [w]) NoUniqueness + -- Create a blank space for the 'segs' + segsSpace <- letExp "blank_segs" =<< eBlank segsType + -- Write back the segs of each branch to the blank space + segs <- foldM scatterRegular segsSpace $ zip iss (irregularS <$> branchesIrregRep) + (_, offsets, num_data) <- exScanAndSum segs + let resultType = Array pt (Shape [num_data]) NoUniqueness + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + elems <- foldM (scatterIrregular offsets) resultSpace $ zip iss branchesIrregRep + flags <- genFlags num_data offsets + pure $ + Irregular $ + IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + Acc {} -> error "transformDistStm: Acc" + Mem {} -> error "transformDistStm: Mem" + + -- Merge the results of the branches and insert the resulting res reps + reps <- zipWithM (mergeResult inds) (L.transpose branch_reps) result_types + pure $ insertReps (zip (map distResTag res) reps) env + Let _ _ (Apply name args rettype s) -> do + let [w] = NE.toList segments + name' = liftFunName name + args' <- ((w, Observe) :) . concat <$> mapM (liftArg segments inps env) args + args_ts <- mapM (subExpType . fst) args' + let dietToUnique Consume = Unique + dietToUnique Observe = Nonunique + dietToUnique ObservePrim = Nonunique + param_ts = zipWith toDecl args_ts $ map (dietToUnique . snd) args' + rettype' = addRetAls param_ts $ liftRetType w $ map fst rettype + + result <- letTupExp (nameToString name' <> "_res") $ Apply name' args' rettype' s + let reps = resultToResReps (map fst rettype) result + pure $ insertReps (zip (map distResTag res) reps) env + _ -> error $ "Unhandled Stm:\n" ++ prettyString stm + +-- | This function walks through the *unlifted* result types +-- and uses the *lifted* results to construct the corresponding res reps. +-- +-- See the 'liftResult' function for the opposite process i.e. +-- turning 'ResRep's into results. +resultToResReps :: [TypeBase s u] -> [VName] -> [ResRep] +resultToResReps types results = + snd $ + L.mapAccumL + ( \rs t -> case t of + Prim {} -> + let (v : rs') = rs + rep = Regular v + in (rs', rep) + Array {} -> + let (_ : segs : flags : offsets : elems : rs') = rs + rep = Irregular $ IrregularRep segs flags offsets elems + in (rs', rep) + Acc {} -> error "resultToResReps: Illegal type 'Acc'" + Mem {} -> error "resultToResReps: Illegal type 'Mem'" + ) + results + types + +distResCerts :: DistEnv -> [DistInput] -> Certs +distResCerts env = Certs . map f + where + f (DistInputFree v _) = v + f (DistInput rt _) = case resVar rt env of + Regular v -> v + Irregular {} -> error "resCerts: irregular" + +transformDistributed :: + M.Map ResTag IrregularRep -> + Segments -> + Distributed -> + Builder GPU () +transformDistributed irregs segments dist = do + let Distributed dstms (DistResults resmap reps) = dist + env <- foldM (transformDistStm segments) env_initial dstms + forM_ (M.toList resmap) $ \(rt, (cs_inps, v, v_t)) -> + certifying (distResCerts env cs_inps) $ + -- FIXME: the copies are because we have too liberal aliases on + -- lifted functions. + case resVar rt env of + Regular v' -> letBindNames [v] $ BasicOp $ Replicate mempty $ Var v' + Irregular irreg -> do + -- It might have an irregular representation, but we know + -- that it is actually regular because it is a result. + let shape = segmentsShape segments <> arrayShape v_t + v_copy <- + letExp (baseString v) . BasicOp $ + Replicate mempty (Var $ irregularD irreg) + letBindNames [v] $ + BasicOp (Reshape ReshapeArbitrary shape v_copy) + forM_ reps $ \(v, r) -> + case r of + Left se -> + letBindNames [v] $ BasicOp $ Replicate (segmentsShape segments) se + Right (DistInputFree arr _) -> + letBindNames [v] $ BasicOp $ SubExp $ Var arr + Right DistInput {} -> + error "replication of irregular identity result" + where + env_initial = DistEnv {distResMap = M.map Irregular irregs} + +transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) +transformStm scope (Let pat _ (Op (Screma w arrs form))) + | Just lam <- isMapSOAC form = do + let arrs' = + zipWith MapArray arrs $ + map paramType (lambdaParams (scremaLambda form)) + (distributed, _) = distributeMap scope pat w arrs' lam + m = transformDistributed mempty (NE.singleton w) distributed + traceM $ prettyString distributed + runReaderT (runBuilder_ m) scope +transformStm _ stm = pure $ oneStm $ soacsStmToGPU stm + +transformStms :: Scope SOACS -> Stms SOACS -> PassM (Stms GPU) +transformStms scope stms = + fold <$> traverse (transformStm (scope <> scopeOf stms)) stms + +-- If the sub-expression is a constant, replicate it to match the shape of `segments` +-- If it's a variable, lookup the variable in the dist inputs and dist env, +-- and if it can't be found it is a free variable, so we replicate it to match the shape of `segments`. +liftSubExp :: Segments -> DistInputs -> DistEnv -> SubExp -> Builder GPU (Type, ResRep) +liftSubExp segments inps env se = case se of + c@(Constant prim) -> + let t = Prim $ primValueType prim + in ((t,) . Regular <$> letExp "lifted_const" (BasicOp $ Replicate (segmentsShape segments) c)) + Var v -> case M.lookup v $ inputReps inps env of + Just (t, Regular v') -> do + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> error "getRepSubExp: Acc" + Mem {} -> error "getRepSubExp: Mem" + Just (t, Irregular irreg) -> pure (t, Irregular irreg) + Nothing -> do + t <- lookupType v + v' <- letExp "free_replicated" $ BasicOp $ Replicate (segmentsShape segments) (Var v) + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> error "getRepSubExp: Acc" + Mem {} -> error "getRepSubExp: Mem" + +liftParam :: SubExp -> FParam SOACS -> PassM ([FParam GPU], ResRep) +liftParam w fparam = + case declTypeOf fparam of + Prim pt -> do + p <- + newParam + (desc <> "_lifted") + (arrayOf (Prim pt) (Shape [w]) Nonunique) + pure ([p], Regular $ paramName p) + Array pt _ u -> do + num_data <- + newParam (desc <> "_num_data") $ Prim int64 + segments <- + newParam (desc <> "_segments") $ + arrayOf (Prim int64) (Shape [w]) Nonunique + flags <- + newParam (desc <> "_F") $ + arrayOf (Prim Bool) (Shape [Var (paramName num_data)]) Nonunique + offsets <- + newParam (desc <> "_O") $ + arrayOf (Prim int64) (Shape [w]) Nonunique + elems <- + newParam (desc <> "_data") $ + arrayOf (Prim pt) (Shape [Var (paramName num_data)]) u + pure + ( [num_data, segments, flags, offsets, elems], + Irregular $ + IrregularRep + { irregularS = paramName segments, + irregularF = paramName flags, + irregularO = paramName offsets, + irregularD = paramName elems + } + ) + Acc {} -> + error "liftParam: Acc" + Mem {} -> + error "liftParam: Mem" + where + desc = baseString (paramName fparam) + +liftArg :: Segments -> DistInputs -> DistEnv -> (SubExp, Diet) -> Builder GPU [(SubExp, Diet)] +liftArg segments inps env (se, d) = do + (_, rep) <- liftSubExp segments inps env se + case rep of + Regular v -> pure [(Var v, d)] + Irregular irreg -> mkIrrep irreg + where + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + t <- lookupType elems + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) elems + -- Only apply the original diet to the 'elems' array + let diets = replicate 4 Observe ++ [d] + pure $ zipWith (curry (first Var)) [num_data, segs, flags', offsets, elems'] diets + +-- Lifts a functions return type such that it matches the lifted functions return type. +liftRetType :: SubExp -> [RetType SOACS] -> [RetType GPU] +liftRetType w = concat . snd . L.mapAccumL liftType 0 + where + liftType i rettype = + let lifted = case rettype of + Prim pt -> pure $ arrayOf (Prim pt) (Shape [Free w]) Nonunique + Array pt _ u -> + let num_data = Prim int64 + segs = arrayOf (Prim int64) (Shape [Free w]) Nonunique + flags = arrayOf (Prim Bool) (Shape [Ext i]) Nonunique + offsets = arrayOf (Prim int64) (Shape [Free w]) Nonunique + elems = arrayOf (Prim pt) (Shape [Ext i]) u + in [num_data, segs, flags, offsets, elems] + Acc {} -> error "liftRetType: Acc" + Mem {} -> error "liftRetType: Mem" + in (i + length lifted, lifted) + +-- Lift a result of a function. +liftResult :: Segments -> DistInputs -> DistEnv -> SubExpRes -> Builder GPU Result +liftResult segments inps env res = map (SubExpRes mempty . Var) <$> vs + where + vs = do + (_, rep) <- liftSubExp segments inps env (resSubExp res) + case rep of + Regular v -> pure [v] + Irregular irreg -> mkIrrep irreg + mkIrrep + ( IrregularRep + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems + } + ) = do + t <- lookupType elems + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) elems + pure [num_data, segs, flags', offsets, elems'] + +liftBody :: SubExp -> DistInputs -> DistEnv -> [DistStm] -> Result -> Builder GPU Result +liftBody w inputs env dstms result = do + let segments = NE.singleton w + env' <- foldM (transformDistStm segments) env dstms + result' <- mapM (liftResult segments inputs env') result + pure $ concat result' + +liftFunName :: Name -> Name +liftFunName name = name <> "_lifted" + +addRetAls :: [DeclType] -> [RetType GPU] -> [(RetType GPU, RetAls)] +addRetAls params rettype = zip rettype $ map possibleAliases rettype + where + aliasable (Array _ _ Nonunique) = True + aliasable _ = False + aliasable_params = + map snd $ filter (aliasable . fst) $ zip params [0 ..] + aliasable_rets = + map snd $ filter (aliasable . declExtTypeOf . fst) $ zip rettype [0 ..] + possibleAliases t + | aliasable t = RetAls aliasable_params aliasable_rets + | otherwise = mempty + +liftFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +liftFunDef const_scope fd = do + let FunDef + { funDefBody = body, + funDefParams = fparams, + funDefRetType = rettype + } = fd + wp <- newParam "w" $ Prim int64 + let w = Var $ paramName wp + (fparams', reps) <- mapAndUnzipM (liftParam w) fparams + let fparams'' = wp : concat fparams' + let inputs = do + (p, i) <- zip fparams [0 ..] + pure (paramName p, DistInput (ResTag i) (paramType p)) + let rettype' = + addRetAls (map paramDeclType fparams'') $ + liftRetType w (map fst rettype) + let (inputs', dstms) = + distributeBody const_scope (Var (paramName wp)) inputs body + env = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + -- Lift the body of the function and get the results + (result, stms) <- + runReaderT + (runBuilder $ liftBody w inputs' env dstms $ bodyResult body) + (const_scope <> scopeOfFParams fparams'') + let name = liftFunName $ funDefName fd + pure $ + fd + { funDefName = name, + funDefBody = Body () stms result, + funDefParams = fparams'', + funDefRetType = rettype' + } + +transformFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +transformFunDef consts_scope fd = do + let FunDef + { funDefBody = Body () stms res, + funDefParams = fparams, + funDefRetType = rettype + } = fd + stms' <- transformStms (consts_scope <> scopeOfFParams fparams) stms + pure $ + fd + { funDefBody = Body () stms' res, + funDefRetType = rettype, + funDefParams = fparams + } + +transformProg :: Prog SOACS -> PassM (Prog GPU) +transformProg prog = do + consts' <- transformStms mempty $ progConsts prog + funs' <- mapM (transformFunDef $ scopeOf (progConsts prog)) $ progFuns prog + lifted_funs <- + mapM (liftFunDef $ scopeOf (progConsts prog)) $ + filter (isNothing . funDefEntryPoint) $ + progFuns prog + pure $ + prog + { progConsts = consts', + progFuns = flatteningBuiltins <> lifted_funs <> funs' + } + +-- | Transform a SOACS program to a GPU program, using flattening. +flattenSOACs :: Pass SOACS GPU +flattenSOACs = + Pass + { passName = "flatten", + passDescription = "Perform full flattening", + passFunction = transformProg + } +{-# NOINLINE flattenSOACs #-} diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs new file mode 100644 index 0000000000..003729a454 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -0,0 +1,483 @@ +{-# LANGUAGE TypeFamilies #-} + +module Futhark.Pass.Flatten.Builtins + ( flatteningBuiltins, + segMap, + genFlags, + genSegScan, + genSegRed, + genScatter, + exScanAndSum, + doSegIota, + doPrefixSum, + doRepIota, + doPartition, + ) +where + +import Control.Monad (forM, (<=<)) +import Control.Monad.State.Strict +import Data.Foldable (toList) +import Data.Maybe (fromMaybe) +import Data.Text qualified as T +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace) +import Futhark.Pass.ExtractKernels.ToGPU (soacsLambdaToGPU) +import Futhark.Tools +import Futhark.Util (unsnoc) + +builtinName :: T.Text -> Name +builtinName = nameFromText . ("builtin#" <>) + +segIotaName, repIotaName, prefixSumName, partitionName :: Name +segIotaName = builtinName "segiota" +repIotaName = builtinName "repiota" +prefixSumName = builtinName "prefixsum" +partitionName = builtinName "partition" + +segMap :: (Traversable f) => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU) +segMap segments f = do + gtids <- traverse (const $ newVName "gtid") segments + space <- mkSegSpace $ zip (toList gtids) (toList segments) + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + res <- f $ fmap Var gtids + ts <- mapM (subExpType . resSubExp) res + pure (map mkResult res, ts) + let kbody = KernelBody () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody + where + mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se + +genScanomap :: (Traversable f) => String -> f SubExp -> Lambda GPU -> [SubExp] -> (f SubExp -> Builder GPU [SubExp]) -> Builder GPU [VName] +genScanomap desc segments lam nes m = do + gtids <- traverse (const $ newVName "gtid") segments + space <- mkSegSpace $ zip (toList gtids) (toList segments) + ((res, res_t), stms) <- runBuilder . localScope (scopeOfSegSpace space) $ do + res <- m $ fmap Var gtids + res_t <- mapM subExpType res + pure (map (Returns ResultMaySimplify mempty) res, res_t) + let kbody = KernelBody () stms res + op = SegBinOp Commutative lam nes mempty + letTupExp desc $ Op $ SegOp $ SegScan lvl space [op] res_t kbody + where + lvl = SegThread SegVirt Nothing + +genScan :: (Traversable f) => String -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genScan desc segments lam nes arrs = + genScanomap desc segments lam nes $ \gtids -> forM arrs $ \arr -> + letSubExp (baseString arr <> "_elem") =<< eIndex arr (toList $ fmap eSubExp gtids) + +-- Also known as a prescan. +genExScan :: (Traversable f) => String -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genExScan desc segments lam nes arrs = + genScanomap desc segments lam nes $ \gtids -> + let Just (outerDims, innerDim) = unsnoc $ toList gtids + in letTupExp' "to_prescan" + =<< eIf + (toExp $ pe64 innerDim .==. 0) + (eBody (map eSubExp nes)) + (eBody (map (`eIndex` (map toExp outerDims ++ [toExp $ pe64 innerDim - 1])) arrs)) + +segScanLambda :: + (MonadBuilder m, BranchType (Rep m) ~ ExtType, LParamInfo (Rep m) ~ Type) => + Lambda (Rep m) -> + m (Lambda (Rep m)) +segScanLambda lam = do + x_flag_p <- newParam "x_flag" $ Prim Bool + y_flag_p <- newParam "y_flag" $ Prim Bool + let ts = lambdaReturnType lam + (xps, yps) = splitAt (length ts) $ lambdaParams lam + mkLambda ([x_flag_p] ++ xps ++ [y_flag_p] ++ yps) $ + bodyBind + =<< eBody + [ eBinOp LogOr (eParam x_flag_p) (eParam y_flag_p), + eIf + (eParam y_flag_p) + (eBody (map eParam yps)) + (pure $ lambdaBody lam) + ] + +genSegScan :: String -> Lambda GPU -> [SubExp] -> VName -> [VName] -> Builder GPU [VName] +genSegScan desc lam nes flags arrs = do + w <- arraySize 0 <$> lookupType flags + lam' <- segScanLambda lam + drop 1 <$> genScan desc [w] lam' (constant False : nes) (flags : arrs) + +genPrefixSum :: String -> VName -> Builder GPU VName +genPrefixSum desc ns = do + ws <- arrayDims <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genScan desc ws add_lam [intConst Int64 0] [ns] + +genExPrefixSum :: String -> VName -> Builder GPU VName +genExPrefixSum desc ns = do + ws <- arrayDims <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genExScan desc ws add_lam [intConst Int64 0] [ns] + +genSegPrefixSum :: String -> VName -> VName -> Builder GPU VName +genSegPrefixSum desc flags ns = do + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genSegScan desc add_lam [intConst Int64 0] flags [ns] + +genScatter :: VName -> SubExp -> (SubExp -> Builder GPU (VName, SubExp)) -> Builder GPU (Exp GPU) +genScatter dest n f = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, n)] + ((res, v_t), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + (i, v) <- f $ Var gtid + dest_t <- lookupType dest + pure (WriteReturns mempty dest [(Slice [DimFix (Var i)], v)], dest_t) + let kbody = KernelBody () stms [res] + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space [v_t] kbody + +genTabulate :: SubExp -> (SubExp -> Builder GPU [SubExp]) -> Builder GPU (Exp GPU) +genTabulate w m = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, w)] + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + ses <- m $ Var gtid + ts <- mapM subExpType ses + pure (map (Returns ResultMaySimplify mempty) ses, ts) + let kbody = KernelBody () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody + +genFlags :: SubExp -> VName -> Builder GPU VName +genFlags m offsets = do + flags_allfalse <- + letExp "flags_allfalse" . BasicOp $ + Replicate (Shape [m]) (constant False) + n <- arraySize 0 <$> lookupType offsets + letExp "flags" <=< genScatter flags_allfalse n $ \gtid -> do + i <- letExp "i" =<< eIndex offsets [eSubExp gtid] + pure (i, constant True) + +genSegRed :: VName -> VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] +genSegRed segments flags offsets elems red = do + scanned <- + genSegScan + "red" + (soacsLambdaToGPU $ redLambda red) + (redNeutral red) + flags + elems + num_segments <- arraySize 0 <$> lookupType offsets + letTupExp "segred" <=< genTabulate num_segments $ \i -> do + n <- letSubExp "n" =<< eIndex segments [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [toExp (pe64 i)] + letTupExp' "segment_res" <=< eIf (toExp $ pe64 n .==. 0) (eBody $ map eSubExp nes) $ + eBody $ + map (`eIndex` [toExp $ pe64 offset + pe64 n - 1]) scanned + where + nes = redNeutral red + +-- Returns (#segments, segment start offsets, sum of segment sizes) +-- Note: If given a multi-dimensional array, +-- `#segments` and `sum of segment sizes` will be arrays, not scalars. +-- `segment start offsets` will always have the same shape as `ks`. +exScanAndSum :: VName -> Builder GPU (SubExp, VName, SubExp) +exScanAndSum ks = do + ns <- arrayDims <$> lookupType ks + -- If `ks` only has a single dimension + -- the size will be a scalar, otherwise it's an array. + ns' <- letExp "ns" $ BasicOp $ case ns of + [] -> error $ "exScanAndSum: Given non-array argument: " ++ prettyString ks + [n] -> SubExp n + _ -> ArrayLit ns (Prim int64) + -- Check if the innermost dimension is empty. + is_empty <- + letExp "is_empty" + =<< ( case ns of + [n] -> toExp (pe64 n .==. 0) + _ -> eLast ns' >>= letSubExp "n" >>= (\n -> toExp $ pe64 n .==. 0) + ) + offsets <- letExp "offsets" =<< toExp =<< genExPrefixSum "offsets" ks + ms <- letExp "ms" <=< segMap (init ns) $ \gtids -> do + let idxs = map toExp gtids + offset <- letExp "offset" =<< eIndex offsets idxs + k <- letExp "k" =<< eIndex ks idxs + m <- + letSubExp "m" + =<< eIf + (toExp is_empty) + (eBody [eSubExp $ intConst Int64 0]) + -- Add last size because 'offsets' is an *exclusive* prefix + -- sum. + (eBody [eBinOp (Add Int64 OverflowUndef) (eLast offset) (eLast k)]) + pure [subExpRes m] + pure (Var ns', offsets, Var ms) + +genSegIota :: VName -> Builder GPU (VName, VName, VName) +genSegIota ks = do + (_n, offsets, m) <- exScanAndSum ks + flags <- genFlags m offsets + ones <- letExp "ones" $ BasicOp $ Replicate (Shape [m]) one + iotas <- genSegPrefixSum "iotas" flags ones + res <- letExp "res" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex iotas [eSubExp i] + letTupExp' "xm1" $ BasicOp $ BinOp (Sub Int64 OverflowUndef) x one + pure (flags, offsets, res) + where + one = intConst Int64 1 + +genRepIota :: VName -> Builder GPU (VName, VName, VName) +genRepIota ks = do + (n, offsets, m) <- exScanAndSum ks + is <- letExp "is" <=< genTabulate n $ \i -> do + o <- letSubExp "o" =<< eIndex offsets [eSubExp i] + k <- letSubExp "n" =<< eIndex ks [eSubExp i] + letTupExp' "i" + =<< eIf + (toExp (pe64 k .==. 0)) + (eBody [eSubExp negone]) + (eBody [toExp $ pe64 o]) + zeroes <- letExp "zeroes" $ BasicOp $ Replicate (Shape [m]) zero + starts <- + letExp "starts" <=< genScatter zeroes n $ \gtid -> do + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, gtid) + flags <- letExp "flags" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex starts [eSubExp i] + letTupExp' "nonzero" =<< toExp (pe64 x .>. 0) + res <- genSegPrefixSum "res" flags starts + pure (flags, offsets, res) + where + zero = intConst Int64 0 + negone = intConst Int64 (-1) + +genPartition :: VName -> VName -> VName -> Builder GPU (VName, VName, VName) +genPartition n k cls = do + let n' = Var n + let k' = Var k + let dims = [k', n'] + -- Create a `[k][n]` array of flags such that `cls_flags[i][j]` + -- is equal 1 if the j'th element is a member of equivalence class `i` i.e. + -- the `i`th row is a flag array for equivalence class `i`. + cls_flags <- + letExp "flags" + <=< segMap dims + $ \[i, j] -> do + c <- letSubExp "c" =<< eIndex cls [toExp j] + cls_flag <- + letSubExp "cls_flag" + =<< eIf + (toExp $ pe64 i .==. pe64 c) + (eBody [toExp $ intConst Int64 1]) + (eBody [toExp $ intConst Int64 0]) + pure [subExpRes cls_flag] + + -- Offsets of each of the individual equivalence classes. + (_, local_offs, _counts) <- exScanAndSum cls_flags + -- The number of elems in each class + counts <- letExp "counts" =<< toExp _counts + -- Offsets of the whole equivalence classes + global_offs <- genExPrefixSum "global_offs" counts + -- Offsets over all of the equivalence classes. + cls_offs <- + letExp "cls_offs" =<< do + segMap dims $ \[i, j] -> do + global_offset <- letExp "global_offset" =<< eIndex global_offs [toExp i] + offset <- + letSubExp "offset" + =<< eBinOp + (Add Int64 OverflowUndef) + (eIndex local_offs [toExp i, toExp j]) + (toExp global_offset) + pure [subExpRes offset] + + scratch <- letExp "scratch" $ BasicOp $ Scratch int64 [n'] + res <- letExp "scatter_res" <=< genScatter scratch n' $ \gtid -> do + c <- letExp "c" =<< eIndex cls [toExp gtid] + ind <- letExp "ind" =<< eIndex cls_offs [toExp c, toExp gtid] + i <- letSubExp "i" =<< toExp gtid + pure (ind, i) + pure (counts, global_offs, res) + +buildingBuiltin :: Builder GPU (FunDef GPU) -> FunDef GPU +buildingBuiltin m = fst $ evalState (runBuilderT m mempty) blankNameSource + +segIotaBuiltin :: FunDef GPU +segIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ do + (flags, offsets, res) <- genSegIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = segIotaName, + funDefRetType = + map + (,mempty) + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +repIotaBuiltin :: FunDef GPU +repIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ do + (flags, offsets, res) <- genRepIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = repIotaName, + funDefRetType = + map + (,mempty) + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +prefixSumBuiltin :: FunDef GPU +prefixSumBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ + varsRes . pure <$> genPrefixSum "res" (paramName nsp) + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = prefixSumName, + funDefRetType = + [(Array int64 (Shape [Free $ Var $ paramName np]) Unique, mempty)], + funDefParams = [np, nsp], + funDefBody = body + } + +partitionBuiltin :: FunDef GPU +partitionBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + kp <- newParam "k" $ Prim int64 + csp <- newParam "cs" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, kp, csp]) . buildBody_ $ do + (counts, offsets, res) <- genPartition (paramName np) (paramName kp) (paramName csp) + pure $ varsRes [counts, offsets, res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = partitionName, + funDefRetType = + map + (,mempty) + [ Array int64 (Shape [Free $ Var $ paramName kp]) Unique, + Array int64 (Shape [Free $ Var $ paramName kp]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique + ], + funDefParams = [np, kp, csp], + funDefBody = body + } + +-- | Builtin functions used in flattening. Must be prepended to a +-- program that is transformed by flattening. The intention is to +-- avoid the code explosion that would result if we inserted +-- primitives everywhere. +flatteningBuiltins :: [FunDef GPU] +flatteningBuiltins = [segIotaBuiltin, repIotaBuiltin, prefixSumBuiltin, partitionBuiltin] + +-- | @[0,1,2,0,1,0,1,2,3,4,...]@. Returns @(flags,offsets,elems)@. +doSegIota :: VName -> Builder GPU (VName, VName, VName) +doSegIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "segiota_flags" + offsets <- newVName "segiota_offsets" + elems <- newVName "segiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doSegIota: bad application") $ + applyRetType + (map fst $ funDefRetType segIotaBuiltin) + (funDefParams segIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName segIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + (map (,mempty) restype) + (Safe, mempty, mempty) + pure (flags, offsets, elems) + +-- | Produces @[0,0,0,1,1,2,2,2,...]@. Returns @(flags, offsets, +-- elems)@. +doRepIota :: VName -> Builder GPU (VName, VName, VName) +doRepIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "repiota_flags" + offsets <- newVName "repiota_offsets" + elems <- newVName "repiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doRepIota: bad application") $ + applyRetType + (map fst $ funDefRetType repIotaBuiltin) + (funDefParams repIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName repIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + (map (,mempty) restype) + (Safe, mempty, mempty) + pure (flags, offsets, elems) + +doPrefixSum :: VName -> Builder GPU VName +doPrefixSum ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + letExp "prefix_sum" $ + Apply + (funDefName prefixSumBuiltin) + [(n, Observe), (Var ns, Observe)] + [(toDecl (staticShapes1 ns_t) Unique, mempty)] + (Safe, mempty, mempty) + +doPartition :: VName -> VName -> Builder GPU (VName, VName, VName) +doPartition k cs = do + cs_t <- lookupType cs + let n = arraySize 0 cs_t + counts <- newVName "partition_counts" + offsets <- newVName "partition_offsets" + res <- newVName "partition_res" + let args = [(n, Prim int64), (Var k, Prim int64), (Var cs, cs_t)] + restype = + fromMaybe (error "doPartition: bad application") $ + applyRetType + (map fst $ funDefRetType partitionBuiltin) + (funDefParams partitionBuiltin) + args + letBindNames [counts, offsets, res] $ + Apply + (funDefName partitionBuiltin) + [(n, Observe), (Var k, Observe), (Var cs, Observe)] + (map (,mempty) restype) + (Safe, mempty, mempty) + pure (counts, offsets, res) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs new file mode 100644 index 0000000000..de02127115 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -0,0 +1,261 @@ +module Futhark.Pass.Flatten.Distribute + ( distributeMap, + distributeBody, + MapArray (..), + mapArrayRowType, + DistResults (..), + DistRep, + ResMap, + Distributed (..), + DistStm (..), + DistInput (..), + DistInputs, + DistType (..), + distInputType, + DistResult (..), + ResTag (..), + ) +where + +import Data.Bifunctor +import Data.List qualified as L +import Data.Map qualified as M +import Data.Maybe +import Futhark.IR.SOACS +import Futhark.Util (nubOrd) +import Futhark.Util.Pretty + +newtype ResTag = ResTag Int + deriving (Eq, Ord, Show) + +-- | Something that is mapped. +data DistInput + = -- | A value bound outside the original map nest. By necessity + -- regular. The type is the parameter type. + DistInputFree VName Type + | -- | A value constructed inside the original map nest. May be + -- irregular. + DistInput ResTag Type + deriving (Eq, Ord, Show) + +type DistInputs = [(VName, DistInput)] + +-- | The type of a 'DistInput'. This corresponds to the parameter +-- type of the original map nest. +distInputType :: DistInput -> Type +distInputType (DistInputFree _ t) = t +distInputType (DistInput _ t) = t + +data DistType + = DistType + -- | Outer regular size. + SubExp + -- | Irregular dimensions on top (but after the leading regular + -- size). + Rank + -- | The regular "element type" - in the worst case, at least a + -- scalar. + Type + deriving (Eq, Ord, Show) + +data DistResult = DistResult {distResTag :: ResTag, distResType :: DistType} + deriving (Eq, Ord, Show) + +data DistStm = DistStm + { distStmInputs :: DistInputs, + distStmResult :: [DistResult], + distStm :: Stm SOACS + } + deriving (Eq, Ord, Show) + +-- | First element of tuple are certificates for this result. +-- +-- Second is the name to which is should be bound. +-- +-- Third is the element type (i.e. excluding shape of segments). +type ResMap = M.Map ResTag ([DistInput], VName, Type) + +-- | The results of a map-distribution that were free or identity +-- mapped in the original map function. These correspond to plain +-- replicated arrays. +type DistRep = (VName, Either SubExp DistInput) + +data DistResults = DistResults ResMap [DistRep] + deriving (Eq, Ord, Show) + +data Distributed = Distributed [DistStm] DistResults + deriving (Eq, Ord, Show) + +instance Pretty ResTag where + pretty (ResTag x) = "r" <> pretty x + +instance Pretty DistInput where + pretty (DistInputFree v _) = pretty v + pretty (DistInput rt _) = pretty rt + +instance Pretty DistType where + pretty (DistType w r t) = + brackets (pretty w) <> pretty r <> pretty t + +instance Pretty DistResult where + pretty (DistResult rt t) = + pretty rt <> colon <+> pretty t + +instance Pretty DistStm where + pretty (DistStm inputs res stm) = + "let" <+> ppTuple' (map pretty res) <+> "=" indent 2 stm' + where + stm' = + "map" + <+> nestedBlock + "{" + "}" + ( stack $ + map onInput inputs + ++ [ pretty stm, + "return" <+> ppTuple' (map pretty res) + ] + ) + onInput (v, inp) = + "for" + <+> parens (pretty v <> colon <+> pretty (distInputType inp)) + <+> "<-" + <+> pretty inp + +instance Pretty Distributed where + pretty (Distributed stms (DistResults resmap reps)) = + stms' res' + where + res' = stack $ map onRes (M.toList resmap) <> map onRep reps + stms' = stack $ map pretty stms + onRes (rt, v) = "let" <+> pretty v <+> "=" <+> pretty rt + onRep (v, Left se) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty se) + onRep (v, Right tag) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty tag) + +resultMap :: [(VName, DistInput)] -> [DistStm] -> Pat Type -> Result -> ResMap +resultMap avail_inputs stms pat res = mconcat $ map f stms + where + f stm = + foldMap g $ zip (distStmResult stm) (patElems (stmPat (distStm stm))) + g (DistResult rt _, pe) = + maybe mempty (M.singleton rt) $ findRes pe + findRes (PatElem v v_t) = do + (SubExpRes cs _, pv) <- + L.find ((Var v ==) . resSubExp . fst) $ zip res $ patNames pat + Just (map findCert $ unCerts cs, pv, v_t) + findCert v = fromMaybe (DistInputFree v (Prim Unit)) $ lookup v avail_inputs + +splitIrregDims :: Names -> Type -> (Rank, Type) +splitIrregDims bound_outside (Array pt shape u) = + let (reg, irreg) = + first reverse $ span regDim $ reverse $ shapeDims shape + in (Rank $ length irreg, Array pt (Shape reg) u) + where + regDim (Var v) = v `nameIn` bound_outside + regDim Constant {} = True +splitIrregDims _ t = (mempty, t) + +freeInput :: [(VName, DistInput)] -> VName -> Maybe (VName, DistInput) +freeInput avail_inputs v = + (v,) <$> lookup v avail_inputs + +patInput :: ResTag -> PatElem Type -> (VName, DistInput) +patInput tag pe = + (patElemName pe, DistInput tag $ patElemType pe) + +distributeBody :: + Scope rep -> + SubExp -> + DistInputs -> + Body SOACS -> + (DistInputs, [DistStm]) +distributeBody outer_scope w param_inputs body = + let ((_, avail_inputs), stms) = + L.mapAccumL distributeStm (ResTag (length param_inputs), param_inputs) $ + stmsToList $ + bodyStms body + in (avail_inputs, stms) + where + bound_outside = namesFromList $ M.keys outer_scope + distType t = uncurry (DistType w) $ splitIrregDims bound_outside t + distributeStm (ResTag tag, avail_inputs) stm = + let pat = stmPat stm + new_tags = map ResTag $ take (patSize pat) [tag ..] + avail_inputs' = + avail_inputs <> zipWith patInput new_tags (patElems pat) + free_in_stm = freeIn stm + used_free = mapMaybe (freeInput avail_inputs) $ namesToList free_in_stm + used_free_types = + mapMaybe (freeInput avail_inputs) + . namesToList + . foldMap (freeIn . distInputType . snd) + $ used_free + stm' = + DistStm + (nubOrd $ used_free_types <> used_free) + (zipWith DistResult new_tags $ map distType $ patTypes pat) + stm + in ((ResTag $ tag + length new_tags, avail_inputs'), stm') + +-- | The input we are mapping over in 'distributeMap'. +data MapArray t + = -- | A straightforward array passed in to a + -- top-level map. + MapArray VName Type + | -- | Something more exotic - distribution will assign it a + -- 'ResTag', but not do anything else. This is used to + -- distributed nested maps whose inputs are produced in the outer + -- nests. + MapOther t Type + +mapArrayRowType :: MapArray t -> Type +mapArrayRowType (MapArray _ t) = t +mapArrayRowType (MapOther _ t) = t + +-- This is used to handle those results that are constants or lambda +-- parameters. +findReps :: [(VName, DistInput)] -> Pat Type -> Lambda SOACS -> [DistRep] +findReps avail_inputs map_pat lam = + mapMaybe f $ zip (patElems map_pat) (bodyResult (lambdaBody lam)) + where + f (pe, SubExpRes _ (Var v)) = + case lookup v avail_inputs of + Nothing -> Just (patElemName pe, Left $ Var v) + Just inp + | v `elem` map paramName (lambdaParams lam) -> + Just (patElemName pe, Right inp) + | otherwise -> Nothing + f (pe, SubExpRes _ (Constant v)) = do + Just (patElemName pe, Left $ Constant v) + +distributeMap :: + Scope rep -> + Pat Type -> + SubExp -> + [MapArray t] -> + Lambda SOACS -> + (Distributed, M.Map ResTag t) +distributeMap outer_scope map_pat w arrs lam = + let ((_, arrmap), param_inputs) = + L.mapAccumL paramInput (ResTag 0, mempty) $ + zip (lambdaParams lam) arrs + (avail_inputs, stms) = + distributeBody outer_scope w param_inputs $ lambdaBody lam + resmap = + resultMap avail_inputs stms map_pat $ + bodyResult (lambdaBody lam) + reps = findReps avail_inputs map_pat lam + in ( Distributed stms $ DistResults resmap reps, + arrmap + ) + where + paramInput (ResTag i, m) (p, MapArray arr _) = + ( (ResTag i, m), + (paramName p, DistInputFree arr $ paramType p) + ) + paramInput (ResTag i, m) (p, MapOther x _) = + ( (ResTag (i + 1), M.insert (ResTag i) x m), + (paramName p, DistInput (ResTag i) $ paramType p) + ) diff --git a/src/Futhark/Passes.hs b/src/Futhark/Passes.hs index 0cf122a5a8..66833f9c48 100644 --- a/src/Futhark/Passes.hs +++ b/src/Futhark/Passes.hs @@ -38,9 +38,9 @@ import Futhark.Pass.ExpandAllocations import Futhark.Pass.ExplicitAllocations.GPU qualified as GPU import Futhark.Pass.ExplicitAllocations.MC qualified as MC import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq -import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify @@ -84,7 +84,7 @@ adPipeline = gpuPipeline :: Pipeline SOACS GPU gpuPipeline = standardPipeline - >>> onePass extractKernels + >>> onePass flattenSOACs >>> passes [ simplifyGPU, optimiseGenRed, diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 16fdba1ef5..41ea48ae12 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -22,6 +22,7 @@ module Futhark.Util partitionMaybe, maybeNth, maybeHead, + unsnoc, lookupWithIndex, splitFromEnd, splitAt3, @@ -189,6 +190,12 @@ maybeHead :: [a] -> Maybe a maybeHead [] = Nothing maybeHead (x : _) = Just x +-- | Split the last element from the list, if it exists. +unsnoc :: [a] -> Maybe ([a], a) +unsnoc [] = Nothing +unsnoc [x] = Just ([], x) +unsnoc (x : xs) = unsnoc xs >>= \(ys, y) -> Just (x : ys, y) + -- | Lookup a value, returning also the index at which it appears. lookupWithIndex :: (Eq a) => a -> [(a, b)] -> Maybe (Int, b) lookupWithIndex needle haystack = diff --git a/tests/flattening/CosminArrayExample.fut b/tests/flattening/CosminArrayExample.fut deleted file mode 100644 index c031ac09e8..0000000000 --- a/tests/flattening/CosminArrayExample.fut +++ /dev/null @@ -1,17 +0,0 @@ --- Problem here is that we need will distribute the map --- let arrs = map (\x -> iota(2*x)) xs --- let arr's = map (\x arr -> reshape( (x,2), arr) $ zip xs arrs --- let res = map(\arr' -> reduce(op(+), 0, arr')) arr's --- == --- input { --- [ 1i64, 2i64, 3i64, 4i64] --- } --- output { --- [1i64, 6i64, 15i64, 28i64] --- } -def main (xs: []i64): []i64 = - map (\(x: i64) -> - let arr = #[unsafe] 0..<(2 * x) - let arr' = #[unsafe] unflatten arr in - reduce (+) 0 (arr'[0]) + reduce (+) 0 (arr'[1]) - ) xs diff --git a/tests/flattening/HighlyNestedMap.fut b/tests/flattening/HighlyNestedMap.fut deleted file mode 100644 index 42ea0087a5..0000000000 --- a/tests/flattening/HighlyNestedMap.fut +++ /dev/null @@ -1,41 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- ] --- [ [ [ [4,5,6] , [1,2,3] ] --- , [ [9,10,11], [6,7,8] ] --- ] --- , [ [ [4,5,6] , [3,2,1] ] --- , [ [11,10,9], [8,7,6] ] --- ] --- ] --- } --- output { --- [[[[5, 7, 9], --- [5, 7, 9]], --- [[15, 17, 19], --- [15, 17, 19]]], --- [[[7, 7, 7], --- [7, 7, 7]], --- [[19, 17, 15], --- [19, 17, 15]]]] --- } -def add1 [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def add2 [n][m] (xs: [n][m]i32, ys: [n][m]i32): [n][m]i32 = - map add1 (zip xs ys) - -def add3 [n][m][l] (xs: [n][m][l]i32, ys: [n][m][l]i32): [n][m][l]i32 = - map add2 (zip xs ys) - -def add4 (xs: [][][][]i32, ys: [][][][]i32): [][][][]i32 = - map add3 (zip xs ys) - -def main (a: [][][][]i32) (b: [][][][]i32): [][][][]i32 = - add4(a,b) diff --git a/tests/flattening/IntmRes1.fut b/tests/flattening/IntmRes1.fut deleted file mode 100644 index 5c90367bc9..0000000000 --- a/tests/flattening/IntmRes1.fut +++ /dev/null @@ -1,23 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[7, 8, 9], --- [16, 17, 18], --- [24, 25, 26], --- [33, 34, 35]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xss: [][]i32) (cs: []i32) (y: i32): [][]i32 = - map (\(xs: []i32, c: i32) -> - let y' = y * c + c - let zs = addToRow(xs,y') in - zs - ) (zip xss cs) diff --git a/tests/flattening/IntmRes2.fut b/tests/flattening/IntmRes2.fut deleted file mode 100644 index 8f4f1bd5cd..0000000000 --- a/tests/flattening/IntmRes2.fut +++ /dev/null @@ -1,30 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[7, 8, 9], --- [10, 11, 12]], --- [[18, 19, 20], --- [21, 22, 23]], --- [[21, 20, 19], --- [22, 23, 24]], --- [[32, 31, 30], --- [35, 34, 33]]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xsss: [][][]i32) (cs: []i32) (y: i32): [][][]i32 = - map (\(xss: [][]i32, c: i32) -> - let y' = y * c + c in - map (\(xs: []i32) -> - addToRow(xs,y') - ) xss - ) (zip xsss cs) diff --git a/tests/flattening/IntmRes3.fut b/tests/flattening/IntmRes3.fut deleted file mode 100644 index 230dbf6405..0000000000 --- a/tests/flattening/IntmRes3.fut +++ /dev/null @@ -1,36 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[[7, 8, 9], --- [10, 11, 12]]], --- [[[18, 19, 20], --- [21, 22, 23]]], --- [[[21, 20, 19], --- [22, 23, 24]]], --- [[[32, 31, 30], --- [35, 34, 33]]]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xssss: [][][][]i32) (cs: []i32) (y: i32): [][][][]i32 = - map (\(xsss: [][][]i32, c: i32) -> - let y' = y * c + c in - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addToRow(xs,y') - ) xss - ) xsss - ) (zip xssss cs) diff --git a/tests/flattening/LoopInv1.fut b/tests/flattening/LoopInv1.fut deleted file mode 100644 index c616c16562..0000000000 --- a/tests/flattening/LoopInv1.fut +++ /dev/null @@ -1,24 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- , [3,2,1], [4,5,6] --- , [8,7,6], [11,10,9] --- ] --- [1,2,3] --- } --- output { --- [[2, 4, 6], --- [5, 7, 9], --- [7, 9, 11], --- [10, 12, 14], --- [4, 4, 4], --- [5, 7, 9], --- [9, 9, 9], --- [12, 12, 12]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xss: [][]i32) (ys: []i32): [][]i32 = - map (\(xs: []i32) -> addRows(xs,ys)) xss diff --git a/tests/flattening/LoopInv2.fut b/tests/flattening/LoopInv2.fut deleted file mode 100644 index 7af5c15362..0000000000 --- a/tests/flattening/LoopInv2.fut +++ /dev/null @@ -1,26 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3] --- } --- output { --- [[[2, 4, 6], --- [5, 7, 9]], --- [[7, 9, 11], --- [10, 12, 14]], --- [[4, 4, 4], --- [5, 7, 9]], --- [[9, 9, 9], --- [12, 12, 12]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xsss: [][][]i32) (ys: []i32): [][][]i32 = - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> addRows(xs,ys)) xss - ) xsss diff --git a/tests/flattening/LoopInv3.fut b/tests/flattening/LoopInv3.fut deleted file mode 100644 index 3cffc4dfe9..0000000000 --- a/tests/flattening/LoopInv3.fut +++ /dev/null @@ -1,34 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3] --- } --- output { --- [[[[2, 4, 6], --- [5, 7, 9]]], --- [[[7, 9, 11], --- [10, 12, 14]]], --- [[[4, 4, 4], --- [5, 7, 9]]], --- [[[9, 9, 9], --- [12, 12, 12]]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xssss: [][][][]i32) (ys: []i32): [][][][]i32 = - map (\(xsss: [][][]i32) -> - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addRows(xs,ys) - ) xss - ) xsss - ) xssss diff --git a/tests/flattening/LoopInvReshape.fut b/tests/flattening/LoopInvReshape.fut deleted file mode 100644 index dd8aab2b2d..0000000000 --- a/tests/flattening/LoopInvReshape.fut +++ /dev/null @@ -1,16 +0,0 @@ --- This example presents difficulty for me right now, but also has a --- large potential for improvement later on. --- --- we could turn it into: --- --- let []i32 bettermain ([]i32 xs, [#n]i32 ys, [#n]i32 zs, [#n]i32 is, [#n]i32 js) = --- map (\i32 (i32 y, i32 z, i32 i, i32 j) -> --- xs[i*z + j] --- , zip(ys,zs,is,js)) - -def main [n][m] (xs: [m]i32, ys: [n]i64, zs: [n]i64, is: [n]i32, js: [n]i32): []i32 = - map (\(y: i64, z: i64, i: i32, j: i32): i32 -> - #[unsafe] - let tmp = unflatten (xs :> [y*z]i32) - in tmp[i,j] - ) (zip4 ys zs is js) diff --git a/tests/flattening/Map-IotaMapReduce.fut b/tests/flattening/Map-IotaMapReduce.fut deleted file mode 100644 index ea6377f1f8..0000000000 --- a/tests/flattening/Map-IotaMapReduce.fut +++ /dev/null @@ -1,14 +0,0 @@ --- == --- input { --- [2,3,4] --- [8,3,2] --- } --- output { --- [8,9,12] --- } -def main [n] (xs: [n]i32) (ys: [n]i32): []i32 = - map (\(x: i32, y: i32): i32 -> - let tmp1 = 0.. - map (\(x: i32): i32 -> - let tmp1 = map i32.i64(iota(i64.i32 x)) - let tmp2 = map (*y) tmp1 in - reduce (+) 0 tmp2 - ) xs - ) (zip xss ys ) diff --git a/tests/flattening/MapIotaReduce.fut b/tests/flattening/MapIotaReduce.fut deleted file mode 100644 index e6840d5bea..0000000000 --- a/tests/flattening/MapIotaReduce.fut +++ /dev/null @@ -1,12 +0,0 @@ --- == --- input { --- [1,2,3,4] --- } --- output { --- [0, 1, 3, 6] --- } -def main (xs: []i32): []i32 = - map (\(x: i32): i32 -> - let tmp = 0.. - reduce (+) 0 xs - ) xss diff --git a/tests/flattening/VectorAddition.fut b/tests/flattening/VectorAddition.fut deleted file mode 100644 index 0b9445729b..0000000000 --- a/tests/flattening/VectorAddition.fut +++ /dev/null @@ -1,10 +0,0 @@ --- == --- input { --- [1,2,3,4] --- [5,6,7,8] --- } --- output { --- [6,8,10,12] --- } -def main (xs: []i32) (ys: []i32): []i32 = - map2 (+) xs ys diff --git a/tests/flattening/binop.fut b/tests/flattening/binop.fut new file mode 100644 index 0000000000..6496c9f5c3 --- /dev/null +++ b/tests/flattening/binop.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1,2,3] [4,5,6] } +-- output { [5,7,9] } + +def main = map2 (i32.+) diff --git a/tests/flattening/concat-check-index.fut b/tests/flattening/concat-check-index.fut new file mode 100644 index 0000000000..8e1f0bb51e --- /dev/null +++ b/tests/flattening/concat-check-index.fut @@ -0,0 +1,27 @@ +-- Validation of flattening with 2 lists +-- +-- == +-- entry: validate_flattening_2 +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64] +-- [0i64, 2i64, 4i64, 6i64, 8i64, 10i64] +-- [0i64,0i64,2i64,6i64,12i64,20i64,30i64] +-- [0i64,0i64, +-- 0i64,1i64,0i64,1i64, +-- 0i64,1i64,2i64,0i64,1i64,4i64, +-- 0i64,1i64,2i64,3i64,0i64,1i64,4i64,9i64, +-- 0i64,1i64,2i64,3i64,4i64,0i64,1i64,4i64,9i64,16i64 +--]} +-- output {[true, true, true, true, true, true]} +entry validate_flattening_2 (ns: []i64) (shp: []i64) (offsets: []i64) (expected: []i64) : []bool = + map2 (\n i -> + let irreg = opaque (iota n `concat` (iota n |> map (**2))) + in + if shp[i] == 0i64 && length irreg == 0i64 then true + else if shp[i] == 0i64 then false + else + let gts = iota shp[i] |> map (\j -> expected[offsets[i] + j]) :> [n + n]i64 + let pairs = zip irreg gts + let eq = map (\(pd, gt) -> pd == gt) pairs + in + reduce (&&) true eq + )ns (indices ns) diff --git a/tests/flattening/concat-iota.fut b/tests/flattening/concat-iota.fut new file mode 100644 index 0000000000..32b7b9517f --- /dev/null +++ b/tests/flattening/concat-iota.fut @@ -0,0 +1,8 @@ +-- Validation of flattening with 2 lists +-- +-- == +-- entry: validate_flattening +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64]} +-- output {[0i64, 0i64, 2i64, 6i64, 12i64, 20i64]} +entry validate_flattening (ns: []i64) : []i64 = + map (\n -> i64.sum (opaque (iota n `concat` iota n))) ns \ No newline at end of file diff --git a/tests/flattening/concat-rep.fut b/tests/flattening/concat-rep.fut new file mode 100644 index 0000000000..eb5e650cbf --- /dev/null +++ b/tests/flattening/concat-rep.fut @@ -0,0 +1,8 @@ +-- Validation of flattening with 3 lists +-- +-- == +-- entry: validate_flattening2 +-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 10i64, 13i64]} +-- output {[0i64, 1i64, 4i64, 9i64, 16i64, 25i64, 100i64, 169i64]} +entry validate_flattening2 (ns: []i64) : []i64 = + map (\n -> i64.sum (opaque ((replicate n 1) `concat` iota n `concat` iota n))) ns \ No newline at end of file diff --git a/tests/flattening/dup2d.fut b/tests/flattening/dup2d.fut new file mode 100644 index 0000000000..e33a722966 --- /dev/null +++ b/tests/flattening/dup2d.fut @@ -0,0 +1,7 @@ +-- == +-- input { [[1,2,3],[4,5,6]] } +-- auto output + +def dup = replicate 2 >-> transpose >-> flatten + +entry main (z: [][]i32) = z |> map dup |> dup diff --git a/tests/flattening/dup3d.fut b/tests/flattening/dup3d.fut new file mode 100644 index 0000000000..1b8e2a228e --- /dev/null +++ b/tests/flattening/dup3d.fut @@ -0,0 +1,9 @@ +-- Currently fails; an array that is too small is produced somehow. I +-- suspect replication. +-- == +-- input { [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] } +-- auto output + +def dup = replicate 5 >-> transpose >-> flatten + +def main (z: [2][3][4]i32) = z |> map (map dup) |> map dup |> dup diff --git a/tests/flattening/flattening-pipeline b/tests/flattening/flattening-pipeline deleted file mode 100755 index ed91df97eb..0000000000 --- a/tests/flattening/flattening-pipeline +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -futhark -s --flattening -i "$1" diff --git a/tests/flattening/flattening-test b/tests/flattening/flattening-test deleted file mode 100755 index 92bc4de552..0000000000 --- a/tests/flattening/flattening-test +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh - -HERE=$(dirname "$0") - -if [ $# -lt 1 ]; then - FILES="$HERE/"*.fut -else - FILES=$* -fi - -futhark-test --only-interpret --interpreter="$HERE/flattening-pipeline" $FILES diff --git a/tests/flattening/function-lifting/func_const.fut b/tests/flattening/function-lifting/func_const.fut new file mode 100644 index 0000000000..d6102298d5 --- /dev/null +++ b/tests/flattening/function-lifting/func_const.fut @@ -0,0 +1,22 @@ +-- Lifting a function with a constants as argument and result +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [7i64, 7i64,10i64,16i64,25i64,37i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) (xs : []i64) : ([]i64, i64) = + let ys = map (x*) xs + in (ys, 7) + +#[noinline] +let foo (x : i64) = + let xs = iota x + let (ys, z) = bar 3 xs + in z + reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_free.fut b/tests/flattening/function-lifting/func_free.fut new file mode 100644 index 0000000000..433ba3156d --- /dev/null +++ b/tests/flattening/function-lifting/func_free.fut @@ -0,0 +1,27 @@ +-- Lifting a function with a free variables as argument and result +-- == +-- entry: main +-- input { [ 0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [280i64,294i64,308i64,322i64,336i64,350i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let v1 : []i64 = [5,9,6] + +#[noinline] +let v2 : []i64 = [3,1,4,1,5] + +#[noinline] +let bar (xs : []i64) (y : i64) : (i64, []i64) = + let z = y + reduce (+) 0 xs + in (z, copy v2) + +#[noinline] +let foo (x : i64) = + let (y, zs) = bar v1 x + let z = reduce (+) 0 zs + in (y * z) + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_fully_irreg.fut b/tests/flattening/function-lifting/func_fully_irreg.fut new file mode 100644 index 0000000000..27382dfba8 --- /dev/null +++ b/tests/flattening/function-lifting/func_fully_irreg.fut @@ -0,0 +1,23 @@ +-- Lifting a function with an irregular +-- parameter and return type +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 3i64, 15i64,45i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : []i64 = + let y = reduce (+) 0 xs + in iota y + +#[noinline] +let foo (x : i64) = + let xs = iota x + let ys = bar xs + in reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_irreg_input.fut b/tests/flattening/function-lifting/func_irreg_input.fut new file mode 100644 index 0000000000..718895b430 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 1i64, 3i64, 6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : i64 = reduce (+) 0 xs + +#[noinline] +let foo (x : i64) = + let xs = iota x + in bar xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_result.fut b/tests/flattening/function-lifting/func_irreg_result.fut new file mode 100644 index 0000000000..0225943179 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_result.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64,1i64,2i64,3i64,4i64, 5i64] } +-- output { [0i64,0i64,1i64,3i64,6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) : []i64 = iota x + +#[noinline] +let foo (x : i64) = + let xs = bar x + in reduce (+) 0 xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_update.fut b/tests/flattening/function-lifting/func_irreg_update.fut new file mode 100644 index 0000000000..b329759b91 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_update.fut @@ -0,0 +1,23 @@ +-- Lifting a function which consumes its argument +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 1i64, 2i64, 4i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar [n] (xs : *[n]i64) (z : i64) (ys : [z]i64) : [n]i64 = + let m = n - z + in xs with [m:n] = ys + +#[noinline] +let foo (a : i64) = + let b = a / 2 + let xs = iota a + let ys = iota b :> [b]i64 + let zs = bar xs b ys + in reduce (+) 0 zs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_mix.fut b/tests/flattening/function-lifting/func_mix.fut new file mode 100644 index 0000000000..cca3be9c17 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix.fut @@ -0,0 +1,25 @@ +-- Lifting a function with both regular and irregular +-- parameters and return types. +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, -1i64, 27i64, 252i64, 1175i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 9i64, 9i64, 0i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + in (iota z, z) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_mix_nested.fut b/tests/flattening/function-lifting/func_mix_nested.fut new file mode 100644 index 0000000000..0431af1e94 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix_nested.fut @@ -0,0 +1,31 @@ +-- Lifting a function that calls another function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0i64, 1i64, 2i64, 3i64, 4i64] } +-- output { [0i64, 0i64, 0i64, 52290i64, 21935100i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 3990i64, 3990i64, 33i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let baz (xs : []i64) (y : i64) : ([]i64, []i64) = + let z = y * reduce (+) 0 xs + in (iota y, iota z) + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + let (as, bs) = baz (iota z) z + let a = reduce (+) 0 as + in (bs, a) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_simple.fut b/tests/flattening/function-lifting/func_simple.fut new file mode 100644 index 0000000000..22002bbf3e --- /dev/null +++ b/tests/flattening/function-lifting/func_simple.fut @@ -0,0 +1,16 @@ +-- Lifting a simple function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [1i64, 2i64, 3i64, 4i64, 5i64, 6i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +def bar (x : i64) = x + 1 + +#[noinline] +def foo (x : i64) = bar x + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/iota-index.fut b/tests/flattening/iota-index.fut new file mode 100644 index 0000000000..a21c5f4096 --- /dev/null +++ b/tests/flattening/iota-index.fut @@ -0,0 +1,10 @@ +-- iota is probably simplified away, but certs must be kept. +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (iota n)[i]) diff --git a/tests/flattening/iota-opaque-index.fut b/tests/flattening/iota-opaque-index.fut new file mode 100644 index 0000000000..065c55d294 --- /dev/null +++ b/tests/flattening/iota-opaque-index.fut @@ -0,0 +1,9 @@ +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (opaque (iota n))[i]) diff --git a/tests/flattening/iota-opaque-slice-red.fut b/tests/flattening/iota-opaque-slice-red.fut new file mode 100644 index 0000000000..56b1218c9c --- /dev/null +++ b/tests/flattening/iota-opaque-slice-red.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1i64,2i64] [0i64,1i64] } +-- output { [0i64,1i64] } +-- input { [1i64,5i64] [0i64,3i64] } +-- output { [0i64,7i64] } +-- input { [1i64,2i64] [0i64,3i64] } +-- error: out of bounds +-- input { [1i64,-2i64] [0i64,1i64] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i64) -> i64.sum (opaque (iota n))[i:]) diff --git a/tests/flattening/iota-red.fut b/tests/flattening/iota-red.fut new file mode 100644 index 0000000000..ba2d5ea6fa --- /dev/null +++ b/tests/flattening/iota-red.fut @@ -0,0 +1,7 @@ +-- == +-- input { [0i64,1i64,2i64] } +-- output { [0i64, 0i64, 1i64] } +-- input { [0i64,1i64,-2i64] } +-- error: Range 0..1..<-2 is invalid + +def main = map (\n -> i64.sum (iota n)) diff --git a/tests/flattening/map-nested-deeper.fut b/tests/flattening/map-nested-deeper.fut new file mode 100644 index 0000000000..f941f80e24 --- /dev/null +++ b/tests/flattening/map-nested-deeper.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [[5],[7]] } +-- output { [7,9] } + +def main = map2 (\n xs -> + #[unsafe] + let A = #[opaque] replicate n xs + let B = #[opaque] map (\x -> (opaque x)[0]+2i32) A + in B[0]) diff --git a/tests/flattening/map-nested-free2d.fut b/tests/flattening/map-nested-free2d.fut new file mode 100644 index 0000000000..57af621dd1 --- /dev/null +++ b/tests/flattening/map-nested-free2d.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [5i64,7i64] [3i64,2i64] } +-- output { [3i64, 2i64] } + +def main = map3 (\n m x -> + #[unsafe] + let A = #[opaque] replicate n (replicate m x) + let B = #[opaque] map (\i -> A[i%x,i%x]) (iota n) + in B[0]) diff --git a/tests/flattening/map-nested.fut b/tests/flattening/map-nested.fut new file mode 100644 index 0000000000..3942a7868d --- /dev/null +++ b/tests/flattening/map-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [5i64,7i64] } +-- output { [20i64, 35i64] } + +def main = map (\n -> i64.sum (map (+2) (iota n))) diff --git a/tests/flattening/map-slice-nested.fut b/tests/flattening/map-slice-nested.fut new file mode 100644 index 0000000000..0b01ac7880 --- /dev/null +++ b/tests/flattening/map-slice-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1i64,2i64,3i64,4i64,5i64] [-5i64,7i64] [2i64,3i64] [3i64,4i64] } +-- output { [-2i64, 11i64] } + +def main A = map3 (\x i j -> i64.sum (map (+x) A[i:j])) diff --git a/tests/flattening/mapout.fut b/tests/flattening/mapout.fut new file mode 100644 index 0000000000..67ee76a39c --- /dev/null +++ b/tests/flattening/mapout.fut @@ -0,0 +1,11 @@ +-- A redomap where part of the result is not reduced. +-- == +-- input { [5i64,7i64] [0i64,1i64] } +-- output { [20i64, 35i64] [0i64, 1i64] } + +def main ns is = map2 (\n (i:i64) -> let is = iota n + let xs = map (+2) is + let ys = map (*i) is + in (i64.sum xs, (opaque ys)[i])) + ns is + |> unzip diff --git a/tests/flattening/match-case/if.fut b/tests/flattening/match-case/if.fut new file mode 100644 index 0000000000..cc1594d32e --- /dev/null +++ b/tests/flattening/match-case/if.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-1i64,1i64,-2i64,2i64,-3i64,3i64] } +-- output { [ 1i64,2i64, 4i64,4i64, 9i64,6i64] } +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [25i64, 9i64,8i64,4i64,0i64, 1i64,6i64,2i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 2i64, 4i64, 6i64, 8i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { [ 1i64, 4i64, 9i64,16i64,25i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = if x < 0 then x * x else x * 2 + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_fully_irreg.fut b/tests/flattening/match-case/if_fully_irreg.fut new file mode 100644 index 0000000000..4efbbe1771 --- /dev/null +++ b/tests/flattening/match-case/if_fully_irreg.fut @@ -0,0 +1,24 @@ +-- == +-- entry: main +-- nobench input { [ 2i64, 7i64, 1i64, 8i64, 7i64] } +-- output { [ 2i64, 23i64, 0i64, 31i64, 23i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 2i64, 6i64, 12i64, 20i64] } +-- nobench input { [ 6i64, 7i64, 8i64, 9i64, 10i64] } +-- output { [16i64, 23i64, 31i64, 40i64, 50i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar [n] (xs : [n]i64) = + if n <= 5 then (false, xs) + else (true, copy xs with [5] = n) + +#[noinline] +let foo (x : i64) = + let xs = iota x in + let (b, ys) = bar xs + let z = reduce (+) 0 ys + in if b then z else z * 2 + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_irreg_input.fut b/tests/flattening/match-case/if_irreg_input.fut new file mode 100644 index 0000000000..8b4a164b0b --- /dev/null +++ b/tests/flattening/match-case/if_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [-1i64,-1i64,6i64,1i64,0i64,-1i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { [-1i64,-1i64,-1i64,-1i64,-1i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = + let ys = iota (i64.abs x) + in if x < 0 then -1 else reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_irreg_result.fut b/tests/flattening/match-case/if_irreg_result.fut new file mode 100644 index 0000000000..218780ae1e --- /dev/null +++ b/tests/flattening/match-case/if_irreg_result.fut @@ -0,0 +1,20 @@ +-- == +-- entry: main +-- nobench input { [ -5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [300i64,36i64,6i64,1i64,0i64, 0i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64, 10i64] } +-- nobench input { [ 1i64,-2i64,-3i64, -4i64, -5i64] } +-- output { [ 0i64, 6i64,36i64,120i64,300i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar (x : i64) = if x < 0 then iota (x*x) else iota x + +#[noinline] +let foo (x : i64) = + let ys = bar x + in reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/match_fully_irreg.fut b/tests/flattening/match-case/match_fully_irreg.fut new file mode 100644 index 0000000000..72f4b6b2c5 --- /dev/null +++ b/tests/flattening/match-case/match_fully_irreg.fut @@ -0,0 +1,25 @@ +-- == +-- entry: main +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [0i64, 1i64, 2i64, 0i64, 1i64, 2i64, 0i64, 1i64, 2i64] } +-- output { [7i64, -5i64, -4i64, 2i64, -1i64, -1i64, 1i64, -1i64, 2i64] } +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [2i64, 2i64, 2i64, 1i64, 1i64, 1i64, 0i64, 0i64, 0i64] } +-- output { [-4i64, -4i64, -4i64, -1i64, -1i64, -1i64, 1i64, 1i64, 1i64] } +-- nobench input { [1i64, 2i64, 3i64] [4i64, 5i64, 6i64] } +-- output { [2i64, 35i64, 135i64] } +-- nobench input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) (y : i64) (zs : []i64) = + let (a, as) = + match (x, y) + case (0,0) -> (3,iota 5) + case (0,b) -> (5,iota b) + case (a,0) -> (a,iota 3) + case (a,b) -> (a*b, zs) + in reduce (+) 0 as - a + +let bar (x : i64) (y : i64) = + let zs = iota (x * y) in foo x y zs + +def main [n] (xs : [n]i64) (ys : [n]i64) = map2 bar xs ys diff --git a/tests/flattening/range-irreg-stride.fut b/tests/flattening/range-irreg-stride.fut new file mode 100644 index 0000000000..464fd4acef --- /dev/null +++ b/tests/flattening/range-irreg-stride.fut @@ -0,0 +1,7 @@ +-- == +-- input { 10i64 [1,2] } +-- output { [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +-- [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]] +-- } + +def main k = map (\s -> (0..s.. [k]i32) diff --git a/tests/flattening/range-opaque-red.fut b/tests/flattening/range-opaque-red.fut new file mode 100644 index 0000000000..a2ca853795 --- /dev/null +++ b/tests/flattening/range-opaque-red.fut @@ -0,0 +1,7 @@ +-- == +-- input { [1i64,2i64] [3i64,3i64] [10i64,8i64] } +-- output { [25i64, 27i64] } +-- input { [1i64,2i64] [3i64,2i64] [10i64,-8i64] } +-- error: Range 2..2..<-8 is invalid + +def main = map3 (\a b c -> i64.sum (opaque (a..b.. (opaque (transpose (opaque xs)))[1,1]) xsss diff --git a/tests/flattening/rearrange1.fut b/tests/flattening/rearrange1.fut new file mode 100644 index 0000000000..371768b399 --- /dev/null +++ b/tests/flattening/rearrange1.fut @@ -0,0 +1,5 @@ +-- == +-- input { [3i64,4i64] } +-- output { [1i64,1i64] } + +def main = map (\n -> ((transpose (replicate (n+1) (iota n))))[1,1]) diff --git a/tests/flattening/redomap1.fut b/tests/flattening/redomap1.fut deleted file mode 100644 index 0621ac3b84..0000000000 --- a/tests/flattening/redomap1.fut +++ /dev/null @@ -1,17 +0,0 @@ --- == --- input { --- [[1,2,3],[1,2,3]] --- [[3,2,1],[6,7,8]] --- } --- output { --- [12, 27] --- } -def main [m][n] (xss: [m][n]i32) (yss: [m][n]i32): [m]i32 = - let final_res = - map (\(xs: [n]i32, ys: [n]i32): i32 -> - let tmp = - map (\(x: i32, y: i32): i32 -> x+y - ) (zip xs ys) in - reduce (+) 0 tmp - ) (zip xss yss) - in final_res diff --git a/tests/flattening/redomap2.fut b/tests/flattening/redomap2.fut deleted file mode 100644 index fa96cdb488..0000000000 --- a/tests/flattening/redomap2.fut +++ /dev/null @@ -1,13 +0,0 @@ --- == --- input { --- [1,2,3] --- [6,7,8] --- } --- output { --- 27 --- } -def main [n] (xs: [n]i32) (ys: [n]i32): i32 = - let tmp = - map (\(x: i32, y: i32): i32 -> x+y - ) (zip xs ys) in - reduce (+) 0 tmp diff --git a/tests/flattening/replicate0.fut b/tests/flattening/replicate0.fut new file mode 100644 index 0000000000..1e34ee3240 --- /dev/null +++ b/tests/flattening/replicate0.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1i64,2i64] [0, 1] [4,5] } +-- output { [4,5] } + +def main = map3 (\n (i:i32) (x:i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) diff --git a/tests/flattening/replicate1.fut b/tests/flattening/replicate1.fut new file mode 100644 index 0000000000..9d61bf17d3 --- /dev/null +++ b/tests/flattening/replicate1.fut @@ -0,0 +1,7 @@ +-- Now we are replicating a regular array. +-- == +-- input { [1i64,2i64] [0, 1] [[4,5],[5,6]] } +-- output { [[4,5],[5,6]] } + +def main = map3 (\n (i:i32) (x:[2]i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) diff --git a/tests/flattening/slice-red.fut b/tests/flattening/slice-red.fut new file mode 100644 index 0000000000..4362860300 --- /dev/null +++ b/tests/flattening/slice-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[0i64,1i64,5i64],[-2i64,9i64,1i64]] [0i64,1i64] } +-- output { [6i64,10i64] } + +def main = map2 (\A (i:i64) -> i64.sum A[i:]) diff --git a/tests/flattening/slice2d-red.fut b/tests/flattening/slice2d-red.fut new file mode 100644 index 0000000000..ad2bc650e5 --- /dev/null +++ b/tests/flattening/slice2d-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[[0i64,1i64],[4i64,5i64]],[[-2i64,9i64],[9i64,2i64]]] [0i64,1i64] [1i64,0i64] } +-- output { [6i64,11i64] } + +def main = map3 (\A (i:i64) (j: i64) -> i64.sum (flatten A[i:,j:])) diff --git a/tests/flattening/update_dimfix.fut b/tests/flattening/update_dimfix.fut new file mode 100644 index 0000000000..6ddad56a10 --- /dev/null +++ b/tests/flattening/update_dimfix.fut @@ -0,0 +1,38 @@ +-- Test with fixed dimension +-- == +-- input { [0,1,2,3,4] [0,1,2,3,4] [5,6,7,8,9] } +-- output { +-- [[[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]]] +-- } + +let main (arr: []i32) (is: []i32) (js: []i32) = + [map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ] diff --git a/tests/flattening/update_fully_irregular.fut b/tests/flattening/update_fully_irregular.fut new file mode 100644 index 0000000000..cbd5a98f93 --- /dev/null +++ b/tests/flattening/update_fully_irregular.fut @@ -0,0 +1,7 @@ +-- Fully irregular test-case +-- == +-- input { [5i64,6i64,7i64] [2i64,3i64,1i64] [3i64,1i64,2i64] [5i64,6i64,3i64] [1i64,2i64,3i64] } +-- output { [4i64,9i64,19i64] } + +entry main [n] (xs : [n]i64) (vs : [n]i64) (is : [n]i64) (js : [n]i64) (ss: [n]i64) = + map5 (\x v i j s -> reduce (+) 0 (iota x with [i:j:s] = iota v)) xs vs is js ss diff --git a/tests/flattening/update_invariant_is.fut b/tests/flattening/update_invariant_is.fut new file mode 100644 index 0000000000..31c8672602 --- /dev/null +++ b/tests/flattening/update_invariant_is.fut @@ -0,0 +1,7 @@ +-- Test with only invariant indices. +-- == +-- input { [4i64,5i64,6i64] [3i64,3i64,3i64] } +-- output { [3i64,7i64,12i64] } + +entry main [n] (xs : [n]i64) (vs : [n]i64) = + map2(\x v -> reduce (+) 0 (iota x with [1:4] = iota v)) xs vs diff --git a/tests/flattening/update_invariant_vs.fut b/tests/flattening/update_invariant_vs.fut new file mode 100644 index 0000000000..ab2a01225a --- /dev/null +++ b/tests/flattening/update_invariant_vs.fut @@ -0,0 +1,8 @@ +-- Test with only invariant 'vs'. +-- == +-- input { [6i64,7i64,8i64] [0i64,1i64,2i64] [5i64,6i64,7i64] } +-- output { [15i64,16i64,18i64] } + +entry main [n] (xs : [n]i64) (is : [n]i64) (js : [n]i64) = + map3(\x i j -> reduce (+) 0 (iota x with [i:j] = iota 5)) xs is js + diff --git a/tests/flattening/update_invariant_xs.fut b/tests/flattening/update_invariant_xs.fut new file mode 100644 index 0000000000..fe9dc215dc --- /dev/null +++ b/tests/flattening/update_invariant_xs.fut @@ -0,0 +1,7 @@ +-- Test with only invariant 'xs'. +-- == +-- input { [1i64,2i64,3i64] [3i64,3i64,3i64] } +-- output { [8i64,8i64,10i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js diff --git a/tests/flattening/update_mixdim.fut b/tests/flattening/update_mixdim.fut new file mode 100644 index 0000000000..77b6aa16bf --- /dev/null +++ b/tests/flattening/update_mixdim.fut @@ -0,0 +1,12 @@ +-- Mixing slices and indexes in complex ways. +-- == +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[0f64,1f64],[4f64,5f64]] +-- } +-- output { [91.0, 99.0] } + +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vs : [n][]f64) = + map4(\i j as vs -> f64.sum(flatten(copy as with [i,i:j] = vs))) is js ass vs diff --git a/tests/flattening/update_multdim.fut b/tests/flattening/update_multdim.fut new file mode 100644 index 0000000000..f90aa3f2b3 --- /dev/null +++ b/tests/flattening/update_multdim.fut @@ -0,0 +1,11 @@ +-- == +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[[0f64,1f64],[2f64,3f64]],[[4f64,5f64],[6f64,7f64]]] +-- } +-- output { [78.0, 94.0] } + +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = + map4(\i j as vs -> f64.sum(flatten(copy as with [i:j,i:j] = vs))) is js ass vss diff --git a/tests/flattening/update_variant_is.fut b/tests/flattening/update_variant_is.fut new file mode 100644 index 0000000000..8c62f3ba46 --- /dev/null +++ b/tests/flattening/update_variant_is.fut @@ -0,0 +1,7 @@ +-- Test with only variant indices. +-- == +-- input { [0i64,3i64,1i64] [5i64,8i64,6i64] } +-- output { [28i64,13i64,23i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2 (\i j -> reduce (+) 0 (iota 8 with [i:j] = iota 5)) is js diff --git a/tests/flattening/update_variant_vs.fut b/tests/flattening/update_variant_vs.fut new file mode 100644 index 0000000000..3fdec27aae --- /dev/null +++ b/tests/flattening/update_variant_vs.fut @@ -0,0 +1,7 @@ +-- Test with only variant 'vs'. +-- == +-- input { [3i64,3i64,3i64] } +-- output { [7i64,7i64,7i64] } + +entry main (vs : []i64) = + map (\v -> reduce (+) 0 (iota 5 with [1:4] = iota v)) vs diff --git a/tests/flattening/update_variant_xs.fut b/tests/flattening/update_variant_xs.fut new file mode 100644 index 0000000000..89730113b3 --- /dev/null +++ b/tests/flattening/update_variant_xs.fut @@ -0,0 +1,7 @@ +-- Test with only variant 'xs'. +-- == +-- input { [4i64,5i64,6i64] } +-- output { [3i64,7i64,12i64] } + +entry main [n] (xs : [n]i64) = + map (\x -> reduce (+) 0 (iota x with [1:4] = iota 3)) xs