diff --git a/CHANGELOG.md b/CHANGELOG.md
index d3c3a3faea..dadc31d326 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
   * `futhark-bench`: Add ``--skip-compilation`` flag.
 
+  * `scatter` expressions nested in `map`s are now parallelised.
+
 ### Removed
 
   * `futhark-dataset`: Removed `--binary-no-header` and
diff --git a/src/Futhark/Pass/ExtractKernels.hs b/src/Futhark/Pass/ExtractKernels.hs
index febfea2f8b..5b10c837e3 100644
--- a/src/Futhark/Pass/ExtractKernels.hs
+++ b/src/Futhark/Pass/ExtractKernels.hs
@@ -991,6 +991,20 @@ maybeDistributeStm bnd@(Let _ aux (BasicOp (Reshape reshape _))) acc =
               map DimNew (newDims reshape)
       addKernel $ oneStm $ Let outerpat aux $ BasicOp $ Reshape reshape' arr
 
+maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d _ _ w))) acc = -- Clause for Concat statements: the statement's own operands are ignored (`_ _`); only dimension `d`, `w` and `aux` are reused.
+  distributeSingleStm acc stm >>= \case -- Dispatch on the result of distributeSingleStm for this single statement.
+    Just (kernels, _, nest, acc')
+      | (outer, _) <- nest, -- Pattern guards: if any guard fails, matching falls through to the `_` alternative below.
+        x:xs <- map snd $ loopNestingParamsAndArrs outer -> -- Needs at least one array from the outer nesting; an empty list fails the guard.
+        localScope (typeEnvFromKernelAcc acc') $ do
+          let d' = d + length (snd nest) + 1 -- New concat dimension: original `d` shifted by 1 plus the number of entries in `snd nest` (presumably one per enclosing nest level -- confirm).
+              outerpat = loopNestingPattern $ fst nest -- The rewritten statement binds the pattern of the outer nesting instead of the original pattern.
+          addKernels kernels
+          addKernel $ oneStm $ Let outerpat aux $ BasicOp $ Concat d' x xs w -- NOTE(review): operands are all arrays of the outer nesting, in nest order; verify this always matches the operand order of the original concat.
+          return acc'
+    _ -> -- No result from distributeSingleStm, or a guard above failed.
+      addStmToKernel stm acc -- Same handling as the catch-all clause that follows.
+
 maybeDistributeStm bnd acc =
   addStmToKernel bnd acc
 
diff --git a/tests/distribution/segconcat0.fut b/tests/distribution/segconcat0.fut
new file mode 100644
index 0000000000..533c509136
--- /dev/null
+++ b/tests/distribution/segconcat0.fut
@@ -0,0 +1,9 @@
+-- Nested concatenation just becomes a concatenation along an inner dimension.
+-- ==
+-- input { [[1,2,3],[4,5,6]] [[3,2,1],[6,5,4]] }
+-- output { [[1,2,3,3,2,1],
+--          [4,5,6,6,5,4]] }
+-- structure distributed { Kernel 0 }
+
+let main (xss: [][]i32) (yss: [][]i32) =
+  map (\(xs, ys) -> concat xs ys) (zip xss yss)
\ No newline at end of file
diff --git a/tests/distribution/segconcat1.fut b/tests/distribution/segconcat1.fut
new file mode 100644
index 0000000000..9302dcfeda
--- /dev/null
+++ b/tests/distribution/segconcat1.fut
@@ -0,0 +1,8 @@
+-- Nested concatenation with more arrays.
+-- ==
+-- input { [[1,2],[3,4],[5,6]] [[1,2],[3,4],[5,6]] [[1,2],[3,4],[5,6]] }
+-- output { [[1,2,1,2,1,2], [3,4,3,4,3,4], [5,6,5,6,5,6]] }
+-- structure distributed { Kernel 0 }
+
+let main (xss: [][]i32) (yss: [][]i32) (zss: [][]i32) =
+  map (\(xs, ys, zs) -> concat xs ys zs) (zip xss yss zss)
\ No newline at end of file