Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten concat #2215

Merged
merged 22 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions src/Futhark/Pass/Flatten.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ import Prelude hiding (div, rem)
--
-- A_O = [0, 0, 6]
--
-- A_II1 = [0,0,0,1,3,3,4,6,6,6]
-- A_II1 = [1,1,1,1,1,1,2,2,2,2]
--
-- A_II2 = [1,1,1,1,1,1,2,2,2,2]
-- A_II2 = [0,0,0,1,3,3,0,2,2,2]

data IrregularRep = IrregularRep
{ -- | Array of size of each segment, type @[]i64@.
Expand Down Expand Up @@ -321,6 +321,107 @@ getIrregRep segments env inps v =
Replicate (segmentsShape segments) (Var v)
mkIrregFromReg segments v'

-- | Do 'map2 (++) A B' where 'A' and 'B' are irregular arrays and have the same
-- number of subarrays.  Works for any number of input arrays, given as the
-- list of their irregular representations.
--
-- The result is built in three steps:
--
-- 1. Compute the size of every result segment as the sum of the sizes of
--    the corresponding segment in each input.
--
-- 2. For each input and each segment, compute where that input's part
--    begins inside the result segment (the sum of the segment sizes of
--    the inputs that precede it).
--
-- 3. Scatter the flat data of each input, one input at a time, to
--    @result segment offset + local offset + index within segment@.
--
-- The 'Segments' and 'DistEnv' arguments are currently unused.  @ns@ is
-- only used for the number of segments (its outer size) and for naming
-- the result.
concatIrreg ::
  Segments ->
  DistEnv ->
  VName ->
  [IrregularRep] ->
  Builder GPU IrregularRep
concatIrreg _segments _env ns reparr = do
  -- Concatenation does not change the number of segments - it simply
  -- makes each of them larger.
  num_segments <- arraySize 0 <$> lookupType ns

  -- Constructs the full list size / shape that should hold the final results:
  -- each result segment is as large as the corresponding input segments
  -- put together.
  ns_full <- letExp (baseString ns <> "_full") <=< segMap (MkSolo num_segments) $
    \(MkSolo i) -> do
      old_segments <-
        forM reparr $ \rep ->
          letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i]
      new_segment <-
        letSubExp "new_segment"
          =<< toExp (sum $ map pe64 old_segments)
      pure $ subExpsRes [new_segment]

  (ns_full_F, ns_full_O, ns_II1) <- doRepIota ns_full

  -- Per input: only the third component of each result is used below.
  -- rep_II1 is indexed by flat element to find its segment, rep_II2 to
  -- find its position within that segment.
  repIota <- mapM (doRepIota . irregularS) reparr
  segIota <- mapM (doSegIota . irregularS) reparr

  let (_, _, rep_II1) = unzip3 repIota
  let (_, _, rep_II2) = unzip3 segIota

  -- Number of flat elements of each input (outer size of its II1 array).
  n_arr <- mapM (fmap (arraySize 0) . lookupType) rep_II1

  -- Calculate offsets for the scatter operations.  This yields one array
  -- per input, indexed by segment: entry i of the j'th array is the sum
  -- of the sizes of segment i in inputs 0..j-1, i.e. where input j's
  -- elements start inside result segment i.
  let shapes = map irregularS reparr
  scatter_offsets <-
    letTupExp "irregular_scatter_offsets" <=< segMap (MkSolo num_segments) $
      \(MkSolo i) -> do
        segment_sizes <-
          forM shapes $ \shape ->
            letSubExp "segment_size" =<< eIndex shape [eSubExp i]
        -- All proper prefixes of the size list: [], [s0], [s0, s1], ...
        let prefixes = L.init $ L.inits segment_sizes
        sumprefix <-
          mapM
            ( letSubExp "segment_prefix"
                <=< foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0)
            )
            prefixes
        pure $ subExpsRes sumprefix

  -- The offsets above are already laid out as one array per input,
  -- indexed by segment, which is exactly how the scatter below reads
  -- them.  No further rearrangement (and no extra kernel) is needed.

  -- Scatter data into result array, one input array at a time.
  --
  -- NOTE(review): the initial destination is ns_II1, which presumably
  -- has i64 elements; confirm this is intended when the concatenated
  -- arrays have another element type.
  elems <-
    foldlM
      ( \dest (reparr1, scatter_offset, n, ii1, ii2) -> do
          letExp "irregular_scatter_elems" <=< genScatter dest n $ \gid -> do
            -- Which segment we are in.
            segment_i <-
              letSubExp "segment_i" =<< eIndex ii1 [eSubExp gid]

            -- Get segment offset in final array
            segment_o <-
              letSubExp "segment_o" =<< eIndex ns_full_O [eSubExp segment_i]

            -- Get local segment offset
            segment_local_o <-
              letSubExp "segment_local_o"
                =<< eIndex scatter_offset [eSubExp segment_i]

            -- Value to write
            v' <-
              letSubExp "v" =<< eIndex (irregularD reparr1) [eSubExp gid]
            -- Position of this element within its own segment.
            o' <- letSubExp "o" =<< eIndex ii2 [eSubExp gid]

            -- Index to write `v'` at
            i <-
              letExp "i" =<< toExp (pe64 o' + pe64 segment_local_o + pe64 segment_o)

            pure (i, v')
      )
      ns_II1
      $ L.zip5 reparr scatter_offsets n_arr rep_II1 rep_II2

  pure $
    IrregularRep
      { irregularS = ns_full,
        irregularF = ns_full_F,
        irregularO = ns_full_O,
        irregularD = elems
      }

-- Do 'map2 replicate ns A', where 'A' is an irregular array (and so
-- is the result, obviously).
replicateIrreg ::
Expand Down Expand Up @@ -536,6 +637,11 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) =
~+~ sExt it (untyped (pe64 v'))
~*~ primExpFromSubExp (IntType it) s'
pure $ insertIrregular ns res_F res_O (distResTag res) res_D' env
Concat 0 arr shp -> do
ns <- dataArr segments env inps shp
reparr <- mapM (getIrregRep segments env inps) (NE.toList arr)
rep' <- concatIrreg segments env ns reparr
pure $ insertRep (distResTag res) (Irregular rep') env
Replicate (Shape [n]) (Var v) -> do
ns <- dataArr segments env inps n
rep <- getIrregRep segments env inps v
Expand Down
27 changes: 27 additions & 0 deletions tests/flattening/concat-check-index.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
-- Checks that flattening `concat` of two irregular arrays places every element at the expected index
--
-- ==
-- entry: validate_flattening_2
-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64]
-- [0i64, 2i64, 4i64, 6i64, 8i64, 10i64]
-- [0i64,0i64,2i64,6i64,12i64,20i64,30i64]
-- [0i64,0i64,
-- 0i64,1i64,0i64,1i64,
-- 0i64,1i64,2i64,0i64,1i64,4i64,
-- 0i64,1i64,2i64,3i64,0i64,1i64,4i64,9i64,
-- 0i64,1i64,2i64,3i64,4i64,0i64,1i64,4i64,9i64,16i64
--]}
-- output {[true, true, true, true, true, true]}
entry validate_flattening_2 (ns: []i64) (shp: []i64) (offsets: []i64) (expected: []i64) : []bool =
-- For every n in `ns` (paired with its index i), build the irregular row
-- `iota n ++ map (**2) (iota n)` and compare it element by element with
-- the slice of `expected` that starts at `offsets[i]` and has `shp[i]`
-- elements.  The result holds one boolean per row.
map2 (\n i ->
-- The row under test; it has n + n elements.
-- `opaque` presumably stops the compiler from simplifying the concat away - confirm.
let irreg = opaque (iota n `concat` (iota n |> map (**2)))
in
-- An expected size of zero is only correct if the computed row is empty too.
if shp[i] == 0i64 && length irreg == 0i64 then true
else if shp[i] == 0i64 then false
else
-- Gather the expected values for row i; the size coercion asserts that
-- shp[i] equals n + n.
let gts = iota shp[i] |> map (\j -> expected[offsets[i] + j]) :> [n + n]i64
let pairs = zip irreg gts
let eq = map (\(pd, gt) -> pd == gt) pairs
in
-- The row passes only if every position matches.
reduce (&&) true eq
)ns (indices ns)
8 changes: 8 additions & 0 deletions tests/flattening/concat-iota.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Checks flattening of `concat` on two irregular arrays by summing each concatenated row
--
-- ==
-- entry: validate_flattening
-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64]}
-- output {[0i64, 0i64, 2i64, 6i64, 12i64, 20i64]}
entry validate_flattening (ns: []i64) : []i64 =
-- For each n, concatenate two copies of `iota n` and sum the result.
-- Each iota sums to n*(n-1)/2, so every output element is n*(n-1).
-- `opaque` presumably stops the compiler from simplifying the concat away - confirm.
map (\n -> i64.sum (opaque (iota n `concat` iota n))) ns
8 changes: 8 additions & 0 deletions tests/flattening/concat-rep.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Checks flattening of `concat` on three irregular arrays by summing each concatenated row
--
-- ==
-- entry: validate_flattening2
-- input {[0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 10i64, 13i64]}
-- output {[0i64, 1i64, 4i64, 9i64, 16i64, 25i64, 100i64, 169i64]}
entry validate_flattening2 (ns: []i64) : []i64 =
-- For each n, concatenate `replicate n 1` with two copies of `iota n`
-- and sum the result: n + n*(n-1) = n*n for every output element.
-- `opaque` presumably stops the compiler from simplifying the concat away - confirm.
map (\n -> i64.sum (opaque ((replicate n 1) `concat` iota n `concat` iota n))) ns
Loading