Skip to content

Commit

Permalink
Flattening of concat (#2214)
Browse files Browse the repository at this point in the history
  • Loading branch information
athas authored Jan 28, 2025
1 parent 1de4f0c commit dbf4689
Show file tree
Hide file tree
Showing 77 changed files with 2,864 additions and 379 deletions.
4 changes: 4 additions & 0 deletions futhark.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ library
Futhark.Pass.ExtractKernels.ToGPU
Futhark.Pass.ExtractMulticore
Futhark.Pass.FirstOrderTransform
Futhark.Pass.Flatten
Futhark.Pass.Flatten.Builtins
Futhark.Pass.Flatten.Distribute
Futhark.Pass.LiftAllocations
Futhark.Pass.LowerAllocations
Futhark.Pass.Simplify
Expand Down Expand Up @@ -475,6 +478,7 @@ library
, lsp-types >= 2.0.1.0
, mainland-pretty >=0.7.1
, cmark-gfm >=0.2.1
, OneTuple
, megaparsec >=9.0.0
, mtl >=2.2.1
, neat-interpolation >=0.3
Expand Down
36 changes: 36 additions & 0 deletions rewritefut/segupdate.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
-- Flat-Parallel Segmented Update
-- ==
-- compiled input { [1i64,2i64,3i64,1i64,2i64,1i64,2i64,3i64,4i64] [3i64,2i64,4i64] [0i64,0i64,0i64,0i64,0i64] [2i64,1i64,2i64] [0i64, 1i64, 0i64] [1i64, 1i64, 2i64] } output { [0i64,0i64,3i64,1i64,0i64,0i64,2i64,0i64,4i64] }

-- Segmented inclusive sum over `arr`. A nonzero entry in `flg` marks the
-- start of a segment: at such a position the running sum restarts from the
-- value found there; otherwise the value is added to the running sum.
let sgmSumI64 [n] (flg : [n]i64) (arr : [n]i64) : [n]i64 =
  -- Scan operator on (flag, value) pairs. The flags are OR-ed together; the
  -- right-hand value replaces the sum when the right-hand flag is set.
  let combine (fa: i64, va: i64) (fb: i64, vb: i64) : (i64, i64) =
    let seen = fa | fb
    in if fb != 0 then (seen, vb)
       else (seen, va + vb)
  let scanned = scan combine (0, 0i64) (zip flg arr)
  -- Only the summed values are returned; the combined flags are dropped.
  in map (\(_, v) -> v) scanned

-- Build a flat flag array for an irregular array-of-arrays.
--
-- `aoa_shp` holds the length of each of the `m` segments, `zero` is the
-- value used for every non-start position, and `aoa_val[k]` is the value
-- written at the first position of segment `k`. The result has length
-- equal to the sum of `aoa_shp`. Segments of length zero contribute no
-- flag: their target index is set to -1, which `scatter` ignores as an
-- out-of-bounds write.
--
-- An empty shape (`m == 0`) yields an empty result instead of an
-- out-of-bounds index error.
let mkFlagArray [m] (aoa_shp: [m]i64) (zero: i64)
                    (aoa_val: [m]i64) : []i64 =
  -- Shift the shape right by one and scan it: this is the exclusive prefix
  -- sum, i.e. the start offset of every segment in the flat array.
  let shp_rot = map (\i -> if i==0i64 then 0i64 else aoa_shp[i-1]) (iota m)
  let shp_scn = scan (+) 0i64 shp_rot
  -- Total flat length. The guard is required because indexing [m-1] of an
  -- empty array is an out-of-bounds access.
  let aoa_len = if m == 0 then 0i64 else shp_scn[m-1]+aoa_shp[m-1]
  -- Empty segments must not overwrite the flag of the following segment.
  let shp_ind = map2 (\shp ind -> if shp==0 then -1i64 else ind) aoa_shp shp_scn
  in scatter (replicate aoa_len zero) shp_ind aoa_val

-- Flat segmented update.
--
-- `xss_val` is the flat destination array (consumed, as marked by `*`) with
-- per-segment lengths `shp_xss`; `vss_val` is the flat array of values to
-- write with per-segment lengths `shp_vss`. For segment `k`, the value at
-- position `p` inside that segment of `vss_val` is written to flat index
-- (start of segment k in xss) + bs[k] + p * ss[k].
let segUpdate [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64)
                        (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 =
  -- Start offset of every destination segment: exclusive prefix sum of the
  -- destination shape.
  let seg_starts =
    scan (+) 0 (map (\k -> if k==0i64 then 0i64 else shp_xss[k-1]) (iota t))
  -- Flag array over the values: the first element of segment k carries k+1,
  -- every other position carries 0.
  let seg_flags = (mkFlagArray shp_vss 0 (1...t :> [t]i64)) :> [m]i64
  -- For every value, the index of the segment it belongs to.
  let seg_of = map (\x -> x - 1) (sgmSumI64 seg_flags seg_flags)
  -- For every value, its position inside its own segment.
  let pos_in_seg = map (\x -> x - 1) (sgmSumI64 seg_flags (replicate m 1))
  -- Flat destination index of every value.
  let targets =
    map2 (\s p -> seg_starts[s] + bs[s] + (p * ss[s])) seg_of pos_in_seg
  in scatter xss_val targets vss_val


-- Program entry point: forwards all six arguments unchanged to `segUpdate`
-- and returns its result. `xss_val` is declared unique (`*`), matching the
-- corresponding parameter of `segUpdate`.
let main [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64)
(shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 =
segUpdate xss_val shp_xss vss_val shp_vss bs ss
2 changes: 2 additions & 0 deletions src/Futhark/CLI/Dev.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq
import Futhark.Pass.ExtractKernels
import Futhark.Pass.ExtractMulticore
import Futhark.Pass.FirstOrderTransform
import Futhark.Pass.Flatten (flattenSOACs)
import Futhark.Pass.LiftAllocations as LiftAllocations
import Futhark.Pass.LowerAllocations as LowerAllocations
import Futhark.Pass.Simplify
Expand Down Expand Up @@ -654,6 +655,7 @@ commandLineOptions =
sinkOption [],
kernelsPassOption reduceDeviceSyncs [],
typedPassOption soacsProg GPU extractKernels [],
typedPassOption soacsProg GPU flattenSOACs [],
typedPassOption soacsProg MC extractMulticore [],
allocateOption "a",
kernelsMemPassOption doubleBufferGPU [],
Expand Down
1 change: 1 addition & 0 deletions src/Futhark/IR/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module Futhark.IR.Pretty
( prettyTuple,
prettyTupleLines,
prettyString,
prettyRet,
PrettyRep (..),
)
where
Expand Down
12 changes: 6 additions & 6 deletions src/Futhark/IR/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import Futhark.Analysis.PrimExp
import Futhark.Construct (instantiateShapes)
import Futhark.IR.Aliases hiding (lookupAliases)
import Futhark.Util
import Futhark.Util.Pretty (align, docText, indent, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty hiding (width)

-- | Information about an error during type checking. The 'Show'
-- instance for this type produces a human-readable description.
Expand Down Expand Up @@ -742,7 +742,7 @@ checkSubExp (Var ident) = context ("In subexp " <> prettyText ident) $ do
lookupType ident

checkCerts :: (Checkable rep) => Certs -> TypeM rep ()
checkCerts (Certs cs) = mapM_ (requireI [Prim Unit]) cs
checkCerts = mapM_ lookupType . unCerts

checkSubExpRes :: (Checkable rep) => SubExpRes -> TypeM rep Type
checkSubExpRes (SubExpRes cs se) = do
Expand Down Expand Up @@ -1023,9 +1023,9 @@ checkExp (Apply fname args rettype_annot _) = do
when (rettype_derived /= rettype_annot) $
bad . TypeError . docText $
"Expected apply result type:"
</> indent 2 (pretty $ map fst rettype_derived)
</> indent 2 (braces $ commasep $ map prettyRet rettype_derived)
</> "But annotation is:"
</> indent 2 (pretty $ map fst rettype_annot)
</> indent 2 (braces $ commasep $ map prettyRet rettype_annot)
consumeArgs paramtypes argflows
checkExp (Loop merge form loopbody) = do
let (mergepat, mergeexps) = unzip merge
Expand Down Expand Up @@ -1252,8 +1252,8 @@ checkStm ::
Stm (Aliases rep) ->
TypeM rep a ->
TypeM rep a
checkStm stm@(Let pat (StmAux (Certs cs) _ (_, dec)) e) m = do
context "When checking certificates" $ mapM_ (requireI [Prim Unit]) cs
checkStm stm@(Let pat (StmAux cs _ (_, dec)) e) m = do
context "When checking certificates" $ checkCerts cs
context "When checking expression annotation" $ checkExpDec dec
context ("When matching\n" <> message " " pat <> "\nwith\n" <> message " " e) $
matchPat pat e
Expand Down
4 changes: 4 additions & 0 deletions src/Futhark/Pass/ExtractKernels/ToGPU.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Futhark.Pass.ExtractKernels.ToGPU
segThread,
soacsLambdaToGPU,
soacsStmToGPU,
soacsExpToGPU,
scopeForGPU,
scopeForSOACs,
injectSOACS,
Expand Down Expand Up @@ -74,6 +75,9 @@ injectSOACS f =
-- | Convert a statement in the 'SOACS' representation to the 'GPU'
-- representation by rephrasing it with @injectSOACS OtherOp@.  The
-- rephrasing is run in the 'Identity' monad.
soacsStmToGPU :: Stm SOACS -> Stm GPU
soacsStmToGPU stm =
  runIdentity $ rephraseStm (injectSOACS OtherOp) stm

-- | Convert an expression in the 'SOACS' representation to the 'GPU'
-- representation by rephrasing it with @injectSOACS OtherOp@.  The
-- rephrasing is run in the 'Identity' monad.
soacsExpToGPU :: Exp SOACS -> Exp GPU
soacsExpToGPU e =
  runIdentity $ rephraseExp (injectSOACS OtherOp) e

-- | Convert a lambda in the 'SOACS' representation to the 'GPU'
-- representation by rephrasing it with @injectSOACS OtherOp@.  The
-- rephrasing is run in the 'Identity' monad.
soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU
soacsLambdaToGPU lam =
  runIdentity $ rephraseLambda (injectSOACS OtherOp) lam

Expand Down
Loading

0 comments on commit dbf4689

Please sign in to comment.