diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index 8254592477..a8371047df 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -41,17 +41,19 @@ def _get_prune_sets(self, state: SDFGState) -> Tuple[Set[str], Set[str]]: """ nsdfg = self.nsdfg - # Note the implementation of `read_and_write_sets()` filters array - # that fully written and read from the read set and only includes - # them in the write set. Thus we have to assume that every write - # is also a read to compensate for this. + # From the input connectors (i.e. data container on the inside) remove + # all those that are not used for reading and from the output containers + # remove those that are not used fro reading. + # NOTE: If a data container is used for reading and writing then only the + # output connector is retained, except the output is a WCR, then the input + # is also retained. read_set, write_set = nsdfg.sdfg.read_and_write_sets() - prune_in = nsdfg.in_connectors.keys() - (read_set | write_set) + prune_in = nsdfg.in_connectors.keys() - read_set prune_out = nsdfg.out_connectors.keys() - write_set - # Note symbols can not be passed through connectors. For that reason - # we do not have to check for them. They should be passed through - # the symbol mapping. + for e in state.out_edges(nsdfg): + if e.data.wcr is not None and e.src_conn in prune_in: + prune_in.remove(e.src_conn) return prune_in, prune_out diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 96fcf9819b..4026ec3e1c 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -4,6 +4,8 @@ import os import copy import pytest +from typing import Tuple + import dace from dace.transformation.dataflow import PruneConnectors from dace.transformation.helpers import nest_state_subgraph @@ -137,6 +139,102 @@ def make_sdfg(): return sdfg_outer +def _make_read_write_sdfg( + conforming_memlet: bool, +) -> Tuple[dace.SDFG, dace.nodes.NestedSDFG]: + """Creates an SDFG for the `test_read_write_{1, 2}` tests. + + The SDFG is rather synthetic, it has an input `in_arg` and adds to every element + 10 and stores that in array `A`, through access node `A1`. From this access node + the data flows into a nested SDFG. However, the data is not read but overwritten, + through a map that writes through access node `inner_A`. That access node + then writes into container `inner_B`. Both `inner_A` and `inner_B` are outputs + of the nested SDFG and are written back into data container `A` and `B`. + + Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B` + will either be associated to `inner_A` (`True`) or `inner_B` (`False`). + This choice has consequences on if the transformation can apply or not. + + Notes: + This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643), + however, it is the historical behaviour. + """ + + # Creating the outer SDFG. + osdfg = dace.SDFG("Outer_sdfg") + ostate = osdfg.add_state(is_start_block=True) + + osdfg.add_array("in_arg", dtype=dace.float64, shape=(4, 4), transient=False) + osdfg.add_array("A", dtype=dace.float64, shape=(4, 4), transient=False) + osdfg.add_array("B", dtype=dace.float64, shape=(4, 4), transient=False) + in_arg, A1, A2, B = (ostate.add_access(name) for name in ["in_arg", "A", "A", "B"]) + + ostate.add_mapped_tasklet( + "producer", + map_ranges={"i": "0:4", "j": "0:4"}, + inputs={"__in": dace.Memlet("in_arg[i, j]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("A[i, j]")}, + input_nodes={in_arg}, + output_nodes={A1}, + external_edges=True, + ) + + # Creating the inner SDFG + isdfg = dace.SDFG("Inner_sdfg") + istate = isdfg.add_state(is_start_block=True) + + isdfg.add_array("inner_A", dtype=dace.float64, shape=(4, 4), transient=False) + isdfg.add_array("inner_B", dtype=dace.float64, shape=(4, 4), transient=False) + inner_A, inner_B = (istate.add_access(name) for name in ["inner_A", "inner_B"]) + + istate.add_mapped_tasklet( + "inner_consumer", + map_ranges={"i": "0:4", "j": "0:4"}, + inputs={}, + code="__out = 10", + outputs={"__out": dace.Memlet("inner_A[i, j]")}, + output_nodes={inner_A}, + external_edges=True, + ) + + # Depending on to which data container this memlet is associated, + # the transformation will apply or it will not apply. + if conforming_memlet: + # Because the `data` field of the inncoming and outgoing memlet are both + # set to `inner_A` the read to `inner_A` will be removed and the + # transformation can apply. + istate.add_nedge( + inner_A, + inner_B, + dace.Memlet("inner_A[0:4, 0:4] -> 0:4, 0:4"), + ) + else: + # Because the `data` filed of the involved memlets differs the read to + # `inner_A` will not be removed thus the transformation can not remove + # the incoming `inner_A`. + istate.add_nedge( + inner_A, + inner_B, + dace.Memlet("inner_B[0:4, 0:4] -> 0:4, 0:4"), + ) + + # Add the nested SDFG + nsdfg = ostate.add_nested_sdfg( + sdfg=isdfg, + parent=osdfg, + inputs={"inner_A"}, + outputs={"inner_A", "inner_B"}, + ) + + # Connecting the nested SDFG + ostate.add_edge(A1, None, nsdfg, "inner_A", dace.Memlet("A[0:4, 0:4]")) + ostate.add_edge(nsdfg, "inner_A", A2, None, dace.Memlet("A[0:4, 0:4]")) + ostate.add_edge(nsdfg, "inner_B", B, None, dace.Memlet("B[0:4, 0:4]")) + + return osdfg, nsdfg + + def test_prune_connectors(n=None): if n is None: n = 64 @@ -234,6 +332,16 @@ def test_unused_retval_2(): assert np.allclose(a, 1) +def test_read_write_1(): + # Because the memlet is conforming, we can apply the transformation. + sdfg = _make_read_write_sdfg(True) + + assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False) + + + + + def test_prune_connectors_with_dependencies(): sdfg = dace.SDFG('tester') A, A_desc = sdfg.add_array('A', [4], dace.float64) @@ -312,6 +420,21 @@ def test_prune_connectors_with_dependencies(): assert np.allclose(np_d, np_d_) +def test_read_write_1(): + # Because the memlet is conforming, we can apply the transformation. + sdfg, nsdfg = _make_read_write_sdfg(True) + + assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True) + + +def test_read_write_2(): + # Because the memlet is not conforming, we can not apply the transformation. + sdfg, nsdfg = _make_read_write_sdfg(False) + + assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -324,3 +447,5 @@ def test_prune_connectors_with_dependencies(): test_unused_retval() test_unused_retval_2() test_prune_connectors_with_dependencies() + test_read_write_1() + test_read_write_2()