Skip to content

Commit

Permalink
The PruneConnectors now again has the old behaviour in when an inpu…
Browse files Browse the repository at this point in the history
…t is maintained and when not.

However, this has some consequences.
I added a test, that shows that this leads, depending on the memlet configuration to different outcome of the transformation.
Okay, it is only affects if the transformation can be applied or not, but still.
This is also in line with my [issue#1643](spcl#1643) that shows that this is a problem.
  • Loading branch information
philip-paul-mueller committed Sep 20, 2024
1 parent da7754a commit e1f2bf6
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 8 deletions.
18 changes: 10 additions & 8 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,19 @@ def _get_prune_sets(self, state: SDFGState) -> Tuple[Set[str], Set[str]]:
"""
nsdfg = self.nsdfg

# Note the implementation of `read_and_write_sets()` filters array
# that fully written and read from the read set and only includes
# them in the write set. Thus we have to assume that every write
# is also a read to compensate for this.
# From the input connectors (i.e. data container on the inside) remove
# all those that are not used for reading and from the output containers
# remove those that are not used fro reading.
# NOTE: If a data container is used for reading and writing then only the
# output connector is retained, except the output is a WCR, then the input
# is also retained.
read_set, write_set = nsdfg.sdfg.read_and_write_sets()
prune_in = nsdfg.in_connectors.keys() - (read_set | write_set)
prune_in = nsdfg.in_connectors.keys() - read_set
prune_out = nsdfg.out_connectors.keys() - write_set

# Note symbols can not be passed through connectors. For that reason
# we do not have to check for them. They should be passed through
# the symbol mapping.
for e in state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
prune_in.remove(e.src_conn)

return prune_in, prune_out

Expand Down
125 changes: 125 additions & 0 deletions tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
import copy
import pytest
from typing import Tuple

import dace
from dace.transformation.dataflow import PruneConnectors
from dace.transformation.helpers import nest_state_subgraph
Expand Down Expand Up @@ -137,6 +139,102 @@ def make_sdfg():
return sdfg_outer


def _make_read_write_sdfg(
conforming_memlet: bool,
) -> Tuple[dace.SDFG, dace.nodes.NestedSDFG]:
"""Creates an SDFG for the `test_read_write_{1, 2}` tests.
The SDFG is rather synthetic, it has an input `in_arg` and adds to every element
10 and stores that in array `A`, through access node `A1`. From this access node
the data flows into a nested SDFG. However, the data is not read but overwritten,
through a map that writes through access node `inner_A`. That access node
then writes into container `inner_B`. Both `inner_A` and `inner_B` are outputs
of the nested SDFG and are written back into data container `A` and `B`.
Depending on `conforming_memlet` the memlet that copies `inner_A` into `inner_B`
will either be associated to `inner_A` (`True`) or `inner_B` (`False`).
This choice has consequences on if the transformation can apply or not.
Notes:
This is most likely a bug, see [issue#1643](https://github.com/spcl/dace/issues/1643),
however, it is the historical behaviour.
"""

# Creating the outer SDFG.
osdfg = dace.SDFG("Outer_sdfg")
ostate = osdfg.add_state(is_start_block=True)

osdfg.add_array("in_arg", dtype=dace.float64, shape=(4, 4), transient=False)
osdfg.add_array("A", dtype=dace.float64, shape=(4, 4), transient=False)
osdfg.add_array("B", dtype=dace.float64, shape=(4, 4), transient=False)
in_arg, A1, A2, B = (ostate.add_access(name) for name in ["in_arg", "A", "A", "B"])

ostate.add_mapped_tasklet(
"producer",
map_ranges={"i": "0:4", "j": "0:4"},
inputs={"__in": dace.Memlet("in_arg[i, j]")},
code="__out = __in + 10.",
outputs={"__out": dace.Memlet("A[i, j]")},
input_nodes={in_arg},
output_nodes={A1},
external_edges=True,
)

# Creating the inner SDFG
isdfg = dace.SDFG("Inner_sdfg")
istate = isdfg.add_state(is_start_block=True)

isdfg.add_array("inner_A", dtype=dace.float64, shape=(4, 4), transient=False)
isdfg.add_array("inner_B", dtype=dace.float64, shape=(4, 4), transient=False)
inner_A, inner_B = (istate.add_access(name) for name in ["inner_A", "inner_B"])

istate.add_mapped_tasklet(
"inner_consumer",
map_ranges={"i": "0:4", "j": "0:4"},
inputs={},
code="__out = 10",
outputs={"__out": dace.Memlet("inner_A[i, j]")},
output_nodes={inner_A},
external_edges=True,
)

# Depending on to which data container this memlet is associated,
# the transformation will apply or it will not apply.
if conforming_memlet:
# Because the `data` field of the inncoming and outgoing memlet are both
# set to `inner_A` the read to `inner_A` will be removed and the
# transformation can apply.
istate.add_nedge(
inner_A,
inner_B,
dace.Memlet("inner_A[0:4, 0:4] -> 0:4, 0:4"),
)
else:
# Because the `data` filed of the involved memlets differs the read to
# `inner_A` will not be removed thus the transformation can not remove
# the incoming `inner_A`.
istate.add_nedge(
inner_A,
inner_B,
dace.Memlet("inner_B[0:4, 0:4] -> 0:4, 0:4"),
)

# Add the nested SDFG
nsdfg = ostate.add_nested_sdfg(
sdfg=isdfg,
parent=osdfg,
inputs={"inner_A"},
outputs={"inner_A", "inner_B"},
)

# Connecting the nested SDFG
ostate.add_edge(A1, None, nsdfg, "inner_A", dace.Memlet("A[0:4, 0:4]"))
ostate.add_edge(nsdfg, "inner_A", A2, None, dace.Memlet("A[0:4, 0:4]"))
ostate.add_edge(nsdfg, "inner_B", B, None, dace.Memlet("B[0:4, 0:4]"))

return osdfg, nsdfg


def test_prune_connectors(n=None):
if n is None:
n = 64
Expand Down Expand Up @@ -234,6 +332,16 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
sdfg = _make_read_write_sdfg(True)

assert first_mode == PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=osdfg, expr_index=0, permissive=False)





def test_prune_connectors_with_dependencies():
sdfg = dace.SDFG('tester')
A, A_desc = sdfg.add_array('A', [4], dace.float64)
Expand Down Expand Up @@ -312,6 +420,21 @@ def test_prune_connectors_with_dependencies():
assert np.allclose(np_d, np_d_)


def test_read_write_1():
# Because the memlet is conforming, we can apply the transformation.
sdfg, nsdfg = _make_read_write_sdfg(True)

assert PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)
sdfg.apply_transformations_repeated(PruneConnectors, validate=True, validate_all=True)


def test_read_write_2():
# Because the memlet is not conforming, we can not apply the transformation.
sdfg, nsdfg = _make_read_write_sdfg(False)

assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--N", default=64)
Expand All @@ -324,3 +447,5 @@ def test_prune_connectors_with_dependencies():
test_unused_retval()
test_unused_retval_2()
test_prune_connectors_with_dependencies()
test_read_write_1()
test_read_write_2()

0 comments on commit e1f2bf6

Please sign in to comment.