diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 8701d459f4..fe74497051 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1376,12 +1376,20 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set = set() write_set = set() for state in self.states(): - for edge in state.parent_graph.in_edges(state): - read_set |= edge.data.free_symbols & self.arrays.keys() # Get dictionaries of subsets read and written from each state rs, ws = state._read_and_write_sets() read_set |= rs.keys() write_set |= ws.keys() + + array_names = self.arrays.keys() + for edge in self.all_interstate_edges(): + read_set |= edge.data.free_symbols & array_names + + # By definition, data that is referenced by symbolic condition expressions + # (branching condition, loop condition, ...) is also part of the read set. + for cfr in self.all_control_flow_regions(): + read_set |= cfr.used_symbols(all_symbols=True, with_contents=False) & array_names + return read_set, write_set def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]: diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index b7b287d77e..5844abc8b2 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -417,6 +417,49 @@ def test_read_write(): assert not PruneConnectors.can_be_applied_to(nsdfg=nsdfg, sdfg=sdfg, expr_index=0, permissive=False) +def test_prune_connectors_with_conditional_block(): + """ + Verifies that a connector to scalar data (here 'cond') in a NestedSDFG is not removed, + when this data is only accessed by condition expressions in ControlFlowRegion nodes. + """ + sdfg = dace.SDFG('tester') + A, A_desc = sdfg.add_array('A', [4], dace.float64) + B, B_desc = sdfg.add_array('B', [4], dace.float64) + COND, COND_desc = sdfg.add_array('COND', [4], dace.bool_) + OUT, OUT_desc = sdfg.add_array('OUT', [4], dace.float64) + + nsdfg = dace.SDFG('nested') + a, _ = nsdfg.add_scalar('a', A_desc.dtype) + b, _ = nsdfg.add_scalar('b', B_desc.dtype) + cond, _ = nsdfg.add_scalar('cond', COND_desc.dtype) + out, _ = nsdfg.add_scalar('out', OUT_desc.dtype) + + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) + a_state = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(cond), then_body) + a_state.add_nedge(a_state.add_access(a), a_state.add_access(out), dace.Memlet(out)) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + b_state = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not ({cond})"), else_body) + b_state.add_nedge(b_state.add_access(b), b_state.add_access(out), dace.Memlet(out)) + + state = sdfg.add_state() + nsdfg_node = state.add_nested_sdfg(nsdfg, sdfg, inputs={a, b, cond}, outputs={out}) + me, mx = state.add_map('map', dict(i="0:4")) + state.add_memlet_path(state.add_access(A), me, nsdfg_node, dst_conn=a, memlet=dace.Memlet(f"{A}[i]")) + state.add_memlet_path(state.add_access(B), me, nsdfg_node, dst_conn=b, memlet=dace.Memlet(f"{B}[i]")) + state.add_memlet_path(state.add_access(COND), me, nsdfg_node, dst_conn=cond, memlet=dace.Memlet(f"{COND}[i]")) + state.add_memlet_path(nsdfg_node, mx, state.add_access(OUT), src_conn=out, memlet=dace.Memlet(f"{OUT}[i]")) + + assert 0 == sdfg.apply_transformations_repeated(PruneConnectors) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -431,3 +474,4 @@ def test_read_write(): test_prune_connectors_with_dependencies() test_read_write_1() test_read_write_2() + test_prune_connectors_with_conditional_block()