Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analysis passes for access range analysis #1484

Merged
merged 3 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def to_json(self):
attrs['is_data_src'] = self._is_data_src

# Fill in legacy (DEPRECATED) values for backwards compatibility
attrs['num_accesses'] = \
str(self.volume) if not self.dynamic else -1
attrs['num_accesses'] = str(self.volume) if not self.dynamic else -1

return {"type": "Memlet", "attributes": attrs}

Expand Down Expand Up @@ -421,13 +420,11 @@ def from_array(dataname, datadesc, wcr=None):
return Memlet.simple(dataname, rng, wcr_str=wcr)

def __hash__(self):
return hash((self.volume, self.src_subset, self.dst_subset, str(self.wcr)))
return hash((self.data, self.volume, self.src_subset, self.dst_subset, str(self.wcr)))

def __eq__(self, other):
return all([
self.volume == other.volume, self.src_subset == other.src_subset, self.dst_subset == other.dst_subset,
self.wcr == other.wcr
])
return all((self.data == other.data, self.volume == other.volume, self.src_subset == other.src_subset,
self.dst_subset == other.dst_subset, self.wcr == other.wcr))

def replace(self, repl_dict):
"""
Expand Down
80 changes: 79 additions & 1 deletion dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import defaultdict
from dace.transformation import pass_pipeline as ppl
from dace import SDFG, SDFGState, properties, InterstateEdge
from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt
from dace.sdfg.graph import Edge
from dace.sdfg import nodes as nd
from dace.sdfg.analysis import cfg
Expand Down Expand Up @@ -505,3 +505,81 @@
del result[desc][write]
top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class AccessRanges(ppl.Pass):
"""
For each data descriptor, finds all memlets used to access it (read/write ranges).
"""

CATEGORY: str = 'Analysis'

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Nothing

Check warning on line 519 in dace/transformation/passes/analysis.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/passes/analysis.py#L519

Added line #L519 was not covered by tests

def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.Memlets

Check warning on line 522 in dace/transformation/passes/analysis.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/passes/analysis.py#L522

Added line #L522 was not covered by tests

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]:
"""
:return: A dictionary mapping each data descriptor name to a set of memlets.
"""
top_result: Dict[int, Dict[str, Set[Memlet]]] = dict()

for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Set[Memlet]] = defaultdict(set)
for state in sdfg.states():
for anode in state.data_nodes():
for e in state.all_edges(anode):
if e.dst is anode and e.dst_conn == 'set': # Skip reference sets
continue
if e.data.is_empty(): # Skip empty memlets
continue
# Find (hopefully propagated) root memlet
e = state.memlet_tree(e).root().edge
result[anode.data].add(e.data)
top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class FindReferenceSources(ppl.Pass):
"""
For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used
to set the reference, the Tasklet is given as a source.
"""

CATEGORY: str = 'Analysis'

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Nothing

Check warning on line 556 in dace/transformation/passes/analysis.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/passes/analysis.py#L556

Added line #L556 was not covered by tests

def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.Memlets

Check warning on line 559 in dace/transformation/passes/analysis.py

View check run for this annotation

Codecov / codecov/patch

dace/transformation/passes/analysis.py#L559

Added line #L559 was not covered by tests

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]]:
"""
:return: A dictionary mapping each data descriptor name to a set of memlets.
"""
top_result: Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]] = dict()

for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Set[Memlet]] = defaultdict(set)
reference_descs = set(k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference))
for state in sdfg.states():
for anode in state.data_nodes():
if anode.data not in reference_descs:
continue
for e in state.in_edges(anode):
if e.dst_conn != 'set':
continue
true_src = state.memlet_path(e)[0].src
if isinstance(true_src, nd.CodeNode):
# Code -> Reference
result[anode.data].add(true_src)
else:
# Array -> Reference
result[anode.data].add(e.data)
top_result[sdfg.sdfg_id] = result
return top_result
61 changes: 61 additions & 0 deletions tests/passes/access_ranges_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the AccessRanges analysis pass. """
import dace
from dace.transformation.passes.analysis import AccessRanges
import numpy as np

N = dace.symbol('N')


def test_simple():

@dace.program
def tester(A: dace.float64[N, N], B: dace.float64[20, 20]):
for i, j in dace.map[0:20, 0:N]:
A[i, j] = 1

sdfg = tester.to_sdfg(simplify=True)
ranges = AccessRanges().apply_pass(sdfg, {})
assert len(ranges) == 1 # Only one SDFG
ranges = ranges[0]
assert len(ranges) == 1 # Only one array is accessed

# Construct write memlet
memlet = dace.Memlet('A[0:20, 0:N]')
memlet._is_data_src = False

assert ranges['A'] == {memlet}


def test_simple_ranges():

@dace.program
def tester(A: dace.float64[N, N], B: dace.float64[20, 20]):
A[:, :] = 0
A[1:21, 1:21] = B
A[0, 0] += 1

sdfg = tester.to_sdfg(simplify=True)
ranges = AccessRanges().apply_pass(sdfg, {})
assert len(ranges) == 1 # Only one SDFG
ranges = ranges[0]
assert len(ranges) == 2 # Two arrays are accessed

assert len(ranges['B']) == 1
assert next(iter(ranges['B'])).src_subset == dace.subsets.Range([(0, 19, 1), (0, 19, 1)])

# Construct read/write memlets
memlet1 = dace.Memlet('A[0:N, 0:N]')
memlet1._is_data_src = False
memlet2 = dace.Memlet('A[1:21, 1:21] -> 0:20, 0:20')
memlet2._is_data_src = False
memlet3 = dace.Memlet('A[0, 0]')
memlet4 = dace.Memlet('A[0, 0]')
memlet4._is_data_src = False

assert ranges['A'] == {memlet1, memlet2, memlet3, memlet4}


if __name__ == '__main__':
test_simple()
test_simple_ranges()
21 changes: 19 additions & 2 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the use of Reference data descriptors. """
import dace
from dace.transformation.passes.analysis import FindReferenceSources
import numpy as np


def test_reference_branch():
def _create_branch_sdfg():
sdfg = dace.SDFG('refbranch')
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [20], dace.float64)
Expand All @@ -29,6 +30,11 @@ def test_reference_branch():
r = finish.add_read('ref')
w = finish.add_write('out')
finish.add_nedge(r, w, dace.Memlet('ref'))
return sdfg


def test_reference_branch():
sdfg = _create_branch_sdfg()

A = np.random.rand(20)
B = np.random.rand(20)
Expand All @@ -41,5 +47,16 @@ def test_reference_branch():
assert np.allclose(out, A)


def test_reference_sources_pass():
sdfg = _create_branch_sdfg()
sources = FindReferenceSources().apply_pass(sdfg, {})
assert len(sources) == 1 # There is only one SDFG
sources = sources[0]
assert len(sources) == 1 and 'ref' in sources # There is one reference
sources = sources['ref']
assert sources == {dace.Memlet('A[0:20]', volume=1), dace.Memlet('B[0:20]', volume=1)}


if __name__ == '__main__':
test_reference_branch()
test_reference_sources_pass()
Loading