Skip to content

Commit

Permalink
Fix data dependent while loop generation
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 13, 2023
1 parent d26c507 commit b83d05d
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 52 deletions.
1 change: 1 addition & 0 deletions dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Before generating the code, run type inference on the SDFG connectors
infer_types.infer_connector_types(sdfg)
Expand Down
76 changes: 53 additions & 23 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2400,36 +2400,66 @@ def _is_test_simple(self, node: ast.AST):
return all(self._is_test_simple(value) for value in node.values)
return is_test_simple

def _visit_test(self, node: ast.Expr):
def _visit_complex_test(self, node: ast.Expr):
test_region = ControlFlowRegion('%s_%s' % ('cond_prep', node.lineno), self.sdfg)
inner_start = test_region.add_state('%s_start_%s' % ('cond_prep', node.lineno))

p_last_cfg_target, p_last_block, p_target = self.last_cfg_target, self.last_block, self.cfg_target
self.cfg_target, self.last_block, self.last_cfg_target = test_region, inner_start, test_region

parsed_node = self.visit(node)
if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1:
parsed_node = parsed_node[0]
if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays:
datadesc = self.sdfg.arrays[parsed_node]
if isinstance(datadesc, data.Array):
parsed_node += '[0]'

self.last_cfg_target, self.last_block, self.cfg_target = p_last_cfg_target, p_last_block, p_target

return parsed_node, test_region

def _visit_test(self, node: ast.Expr) -> Tuple[str, str, bool]:
is_test_simple = self._is_test_simple(node)

# Visit test-condition
if not is_test_simple:
parsed_node = self.visit(node)
if isinstance(parsed_node, (list, tuple)) and len(parsed_node) == 1:
parsed_node = parsed_node[0]
if isinstance(parsed_node, str) and parsed_node in self.sdfg.arrays:
datadesc = self.sdfg.arrays[parsed_node]
if isinstance(datadesc, data.Array):
parsed_node += '[0]'
parsed_node, test_region = self._visit_complex_test(node)
self.cfg_target.add_node(test_region)
self._on_block_added(test_region)
else:
parsed_node = astutils.unparse(node)
test_region = None

# Generate conditions
cond = astutils.unparse(parsed_node)
cond_else = astutils.unparse(astutils.negate_expr(parsed_node))

return cond, cond_else
return cond, cond_else, test_region

def visit_While(self, node: ast.While):
# Get loop condition expression
loop_cond, _ = self._visit_test(node.test)
# Get loop condition expression and create the necessary states for it.
loop_cond, _, test_region = self._visit_test(node.test)
loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False)

# Parse body
self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region,
unconnected_last_block=False)

if test_region is not None:
iter_end_blocks = set()
iter_end_blocks.update(loop_region.continue_states)
for inner_node in loop_region.nodes():
if loop_region.out_degree(inner_node) == 0:
iter_end_blocks.add(inner_node)
loop_region.continue_states = set()

test_region_copy = copy.deepcopy(test_region)
loop_region.add_node(test_region_copy)

for block in iter_end_blocks:
loop_region.add_edge(block, test_region_copy, dace.InterstateEdge())

# Add symbols from test as necessary
symcond = pystr_to_symbolic(loop_cond)
if symbolic.issymbolic(symcond):
Expand All @@ -2455,32 +2485,32 @@ def visit_While(self, node: ast.While):
self.last_block = loop_region

def visit_Break(self, node: ast.Break):
if not isinstance(self.cfg_target, LoopRegion):
error_msg = "'break' is only supported inside for and while loops "
if isinstance(self.cfg_target, LoopRegion):
self.cfg_target.break_states.append(self.last_block)
else:
error_msg = "'break' is only supported inside loops "
if self.nested:
error_msg += ("('break' is not supported in Maps and cannot be "
" used in nested DaCe program calls to break out "
" of loops of outer scopes)")
error_msg += ("('break' is not supported in Maps and cannot be used in nested DaCe program calls to "
" break out of loops of outer scopes)")
raise DaceSyntaxError(self, node, error_msg)
self.cfg_target.break_states.append(self.last_block)

def visit_Continue(self, node: ast.Continue):
if not isinstance(self.cfg_target, LoopRegion):
error_msg = ("'continue' is only supported inside for and while loops ")
if isinstance(self.cfg_target, LoopRegion):
self.cfg_target.continue_states.append(self.last_block)
else:
error_msg = ("'continue' is only supported inside loops ")
if self.nested:
error_msg += ("('continue' is not supported in Maps and cannot "
" be used in nested DaCe program calls to "
error_msg += ("('continue' is not supported in Maps and cannot be used in nested DaCe program calls to "
" continue loops of outer scopes)")
raise DaceSyntaxError(self, node, error_msg)
self.cfg_target.continue_states.append(self.last_block)

def visit_If(self, node: ast.If):
# Add a guard state
self._add_state('if_guard')
self.last_block.debuginfo = self.current_lineinfo

# Generate conditions
cond, cond_else = self._visit_test(node.test)
cond, cond_else, _ = self._visit_test(node.test)

# Visit recursively
laststate, first_if_state, last_if_state, return_stmt = \
Expand Down
1 change: 1 addition & 0 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF

if not self.use_experimental_cfg_blocks:
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Apply simplification pass automatically
if not cached and (simplify == True or
Expand Down
1 change: 1 addition & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2181,6 +2181,7 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG':
# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)
sdutils.inline_control_flow_regions(sdfg)

# Rename SDFG to avoid runtime issues with clashing names
index = 0
Expand Down
51 changes: 25 additions & 26 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,31 @@ def __str__(self):
def __repr__(self) -> str:
return f'ControlFlowBlock ({self.label})'

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k in ('_parent_graph', '_sdfg'): # Skip derivative attributes
continue
setattr(result, k, copy.deepcopy(v, memo))

for k in ('_parent_graph', '_sdfg'):
if id(getattr(self, k)) in memo:
setattr(result, k, memo[id(getattr(self, k))])
else:
setattr(result, k, None)

for node in result.nodes():
if isinstance(node, nd.NestedSDFG):
try:
node.sdfg.parent = result
except AttributeError:
# NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute.
# TODO: Investigate why this happens.
pass
return result

@property
def label(self) -> str:
return self._label
Expand Down Expand Up @@ -1192,31 +1217,6 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None):
self.location = location if location is not None else {}
self._default_lineinfo = None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k in ('_parent_graph', '_sdfg'): # Skip derivative attributes
continue
setattr(result, k, copy.deepcopy(v, memo))

for k in ('_parent_graph', '_sdfg'):
if id(getattr(self, k)) in memo:
setattr(result, k, memo[id(getattr(self, k))])
else:
setattr(result, k, None)

for node in result.nodes():
if isinstance(node, nd.NestedSDFG):
try:
node.sdfg.parent = result
except AttributeError:
# NOTE: There are cases where a NestedSDFG does not have `sdfg` attribute.
# TODO: Investigate why this happens.
pass
return result

@property
def parent(self):
""" Returns the parent SDFG of this state. """
Expand Down Expand Up @@ -2459,7 +2459,6 @@ def add_state_after(self, state: SDFGState, label=None, is_start_state=False) ->
self.add_edge(state, new_state, dace.sdfg.InterstateEdge())
return new_state

@abc.abstractmethod
def _used_symbols_internal(self,
all_symbols: bool,
defined_syms: Optional[Set] = None,
Expand Down
31 changes: 30 additions & 1 deletion dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.sdfg import SDFG
from dace.sdfg.nodes import Node, NestedSDFG
from dace.sdfg.state import SDFGState, StateSubgraphView, LoopRegion, ControlFlowBlock, GraphT
from dace.sdfg.state import SDFGState, StateSubgraphView, LoopRegion, ControlFlowBlock, ControlFlowRegion, GraphT
from dace.sdfg.scope import ScopeSubgraphView
from dace.sdfg import nodes as nd, graph as gr, propagation
from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic
Expand Down Expand Up @@ -1276,6 +1276,35 @@ def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = No
return counter


def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int:
# Avoid import loops
from dace.transformation.interstate import ControlFlowRegionInline

counter = 0
blocks = [(n, p) for n, p in sdfg.all_nodes_recursive()
if isinstance(n, ControlFlowRegion) and not isinstance(n, LoopRegion)]

for _block, _graph in optional_progressbar(reversed(blocks), title='Inlining control flow blocks',
n=len(blocks), progress=progress):
block: ControlFlowBlock = _block
graph: GraphT = _graph
id = block.sdfg.sdfg_id

# We have to reevaluate every time due to changing IDs
block_id = graph.node_id(block)

candidate = {
ControlFlowRegionInline.region: block,
}
inliner = ControlFlowRegionInline()
inliner.setup_match(graph, id, block_id, candidate, 0, override=True)
if inliner.can_be_applied(graph, 0, block.sdfg, permissive=permissive):
inliner.apply(graph, block.sdfg)
counter += 1

return counter


def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int:
"""
Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/interstate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
""" This module initializes the inter-state transformations package."""

from .control_flow_inline import LoopRegionInline
from .control_flow_inline import LoopRegionInline, ControlFlowRegionInline
from .state_fusion import StateFusion
from .state_fusion_with_happens_before import StateFusionExtended
from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination,
Expand Down
56 changes: 55 additions & 1 deletion dace/transformation/interstate/control_flow_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,63 @@
from dace.transformation import transformation


class ControlFlowRegionInline(transformation.MultiStateTransformation):
"""
Inlines a control flow region into a single state machine.
"""

region = transformation.PatternNode(ControlFlowRegion)

@staticmethod
def annotates_memlets():
return False

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.region)]

def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool:
if isinstance(self.region, LoopRegion):
return False
return True

def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]:
parent: ControlFlowRegion = graph

internal_start = self.region.start_block

end_state = parent.add_state(self.region.label + '_end')

# Add all region states and make sure to keep track of all the ones that need to be connected in the end.
to_connect: Set[SDFGState] = set()
for node in self.region.nodes():
parent.add_node(node)
if self.region.out_degree(node) == 0:
to_connect.add(node)

# Add all region edges.
for edge in self.region.edges():
parent.add_edge(edge.src, edge.dst, edge.data)

# Redirect all edges to the region to the internal start state.
for b_edge in parent.in_edges(self.region):
parent.add_edge(b_edge.src, internal_start, b_edge.data)
parent.remove_edge(b_edge)
# Redirect all edges exiting the region to instead exit the end state.
for a_edge in parent.out_edges(self.region):
parent.add_edge(end_state, a_edge.dst, a_edge.data)
parent.remove_edge(a_edge)

for node in to_connect:
parent.add_edge(node, end_state, InterstateEdge())

# Remove the original loop.
parent.remove_node(self.region)


class LoopRegionInline(transformation.MultiStateTransformation):
"""
Inlines a loop regions into a single state machine.
Inlines a loop region into a single state machine.
"""

loop = transformation.PatternNode(LoopRegion)
Expand Down

0 comments on commit b83d05d

Please sign in to comment.