Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant Propagation (partially) Fails If NestedSDFG Has Multiple States #1817

Open
philip-paul-mueller opened this issue Dec 6, 2024 · 1 comment

Comments

@philip-paul-mueller
Copy link
Collaborator

philip-paul-mueller commented Dec 6, 2024

I found a bug in constant propagation that I am unable to fix.
The bug is related to nested SDFGs that have multiple states; it does not occur if the nested SDFG has only one state.
My main test case is a simple Map whose range is N, which is also the size of the data containers.
Furthermore, there is the symbol lim_area that affects the output of the mapped Tasklets.

I needed to replace some symbols with constant values.
So I looked at ConstantPropagation and saw that it offers the initial_symbols argument, which I use to pass the value mapping, i.e. {'lim_area': True, 'N': 10}.
When I run this on an SDFG containing a NestedSDFG whose SDFG has only one state, the Map is transformed as I expected:
single_state_nested

However, if the NestedSDFG contains multiple states, then only one symbol, lim_area, is replaced.
multiple_state_nested

Interestingly, the symbol mapping was {'lim_area': 'lim_area', 'N': 'N'} before constant propagation, but afterwards it is {'N': '10'}. So instead of removing the entry, as it did for lim_area, it replaced the value with '10', which is a string (I passed an integer).

As a side note, if the nested SDFG has only one state, the symbol mapping is empty afterwards.
This means that the transformation only partially fails, as it is able to fully replace lim_area.

This is the reproducer:

from typing import Any, Final, Iterable, Optional, TypeAlias, Union, Literal, overload

import dace
from dace import (
    data as dace_data,
    properties as dace_properties,
    subsets as dace_subsets,
    transformation as dace_transformation,
    nodes as dace_nodes
)
from dace.sdfg import nodes as dace_nodes
from dace.transformation import (
    dataflow as dace_dataflow,
    pass_pipeline as dace_ppl,
    passes as dace_passes,
)


@overload
def count_nodes(
    graph: Union[dace.SDFG, dace.SDFGState],
    node_type: tuple[type, ...] | type,
    return_nodes: Literal[False] = False,
) -> int: ...


@overload
def count_nodes(
    graph: Union[dace.SDFG, dace.SDFGState],
    node_type: tuple[type, ...] | type,
    return_nodes: Literal[True],
) -> list[dace_nodes.Node]: ...


def count_nodes(
    graph: Union[dace.SDFG, dace.SDFGState],
    node_type: tuple[type, ...] | type,
    return_nodes: bool = False,
) -> Union[int, list[dace_nodes.Node]]:
    """Count, or collect, the nodes of a particular type in `graph`.

    If `graph` is an SDFGState, only the nodes inside this state are considered.
    If `graph` is an SDFG, the nodes of all states returned by `graph.states()`
    are considered.

    Args:
        graph: The graph to scan.
        node_type: The type, or tuple of types, of nodes to look for; it is
            passed directly to `isinstance()`.
        return_nodes: If `True`, return the matching nodes instead of their number.

    Returns:
        The number of matching nodes, or the list of matching nodes if
        `return_nodes` is `True`.
    """
    # NOTE(review): depending on the DaCe version, `SDFG.states()` may not descend
    #  into nested control-flow regions; confirm if such SDFGs are scanned.
    states = graph.states() if isinstance(graph, dace.SDFG) else [graph]
    found_nodes: list[dace_nodes.Node] = [
        node
        for state in states
        for node in state.nodes()
        if isinstance(node, node_type)
    ]
    return found_nodes if return_nodes else len(found_nodes)


def check_shapes(
        sdfg: dace.SDFG,
        expected_shape: tuple[Any, ...],
        to_string: bool = True,
) -> bool:
    """Check that every data descriptor in `sdfg.arrays` has `expected_shape`.

    Args:
        sdfg: The SDFG whose registered arrays are inspected.
        expected_shape: The shape every array must have. With `to_string=True`
            it should contain strings, e.g. `("N",)`. With `to_string=False` it
            contains the raw values to compare against, e.g. `(10,)`.
        to_string: If `True` (the default), each shape entry is converted with
            `str()` before the comparison.

    Returns:
        `True` if all arrays have the expected shape. This is vacuously `True`
        when `sdfg` has no arrays.
    """
    def normalize(dim: Any) -> Any:
        return str(dim) if to_string else dim

    return all(
        tuple(normalize(dim) for dim in desc.shape) == expected_shape
        for desc in sdfg.arrays.values()
    )


def check_maps(
        sdfg: dace.SDFG,
        expected_end: str,
) -> bool:
    """Check that the first dimension of every Map in `sdfg` ends at `expected_end`.

    The stored range end is inclusive; a Map created with range `0:N` is
    expected to match `"N"`. Therefore `1` is added to the stored end before it
    is converted to a string and compared. Only the first dimension of each Map
    is inspected. If `sdfg` contains no MapEntry nodes, the check succeeds.
    """
    for entry in count_nodes(graph=sdfg, node_type=dace_nodes.MapEntry, return_nodes=True):
        first_dim_end = entry.map.range.ranges[0][1]
        if str(first_dim_end + 1) != expected_end:
            return False
    return True


def check_tasklets(
        sdfg: dace.SDFG,
        expected_symbols: Optional[set[str]] = None,
        forbidden_symbols: Optional[set[str]] = None,
) -> bool:
    """Check the free symbols of every Tasklet in `sdfg`.

    Every Tasklet must use all of `expected_symbols` and none of
    `forbidden_symbols`. At least one of the two sets has to be given.

    Returns:
        `True` if all Tasklets satisfy both conditions, otherwise `False`.
    """
    assert expected_symbols is not None or forbidden_symbols is not None
    required = expected_symbols or set()
    excluded = forbidden_symbols or set()

    for tasklet in count_nodes(graph=sdfg, node_type=dace_nodes.Tasklet, return_nodes=True):
        used = tasklet.free_symbols
        if not required.issubset(used) or not excluded.isdisjoint(used):
            return False
    return True


def make_multi_state_sdfg() -> dace.SDFG:
    """Build an SDFG that branches on `lim_area` between two mapped computations.

    The state machine works as follows:
    - The start state `stateS` moves to `stateT` if `lim_area` holds.
    - Otherwise, it moves to `stateF`.
    - Both branches then continue to the empty state `stateJ`.

    Each branch state contains a Map over `0:N` that reads `A` and writes `B`.
    Both arrays have shape `(N,)`, and the tasklets depend on `N` and `lim_area`.
    """
    sdfg = dace.SDFG("multi_state_sdfg")
    start = sdfg.add_state("stateS", is_start_block=True)
    sdfg.add_symbol("N", dace.int64)
    sdfg.add_symbol("lim_area", dace.bool_)
    for array_name in ("A", "B"):
        sdfg.add_array(array_name, shape=("N",), dtype=dace.float64, transient=False)

    # (state label, tasklet label, tasklet code) for the "then" and "else" branches.
    branch_specs = (
        ("stateT", "Tcomp", "__out = (__in +  2 * N) if lim_area else (__in - 3 * N)"),
        ("stateF", "Fcomp", "__out = (__in +  3 * N) if lim_area else (__in - 4 * N)"),
    )
    branches = []
    for state_label, tasklet_label, tasklet_code in branch_specs:
        branch = sdfg.add_state(state_label, is_start_block=False)
        branch.add_mapped_tasklet(
            tasklet_label,
            map_ranges={"__i": "0:N"},
            inputs={"__in": dace.Memlet("A[__i]")},
            code=tasklet_code,
            outputs={"__out": dace.Memlet("B[__i]")},
            external_edges=True,
        )
        branches.append(branch)
    then_state, else_state = branches

    join = sdfg.add_state("stateJ", is_start_block=False)
    sdfg.add_edge(start, then_state, dace.InterstateEdge(condition="lim_area"))
    sdfg.add_edge(start, else_state, dace.InterstateEdge(condition="not lim_area"))
    for branch in branches:
        sdfg.add_edge(branch, join, dace.InterstateEdge())
    sdfg.validate()
    return sdfg


def make_single_state_sdfg() -> dace.SDFG:
    """Build a one-state SDFG that computes `B` from `A` with a Map over `0:N`.

    `A` and `B` are non-transient arrays of shape `(N,)`. The tasklet result
    depends on both symbols `N` and `lim_area`.
    """
    sdfg = dace.SDFG("single_state_sdfg")
    only_state = sdfg.add_state(is_start_block=True)
    sdfg.add_symbol("N", dace.int64)
    sdfg.add_symbol("lim_area", dace.bool_)
    for array_name in ("A", "B"):
        sdfg.add_array(array_name, shape=("N",), dtype=dace.float64, transient=False)

    tasklet_code = "__out = (__in + N) if lim_area else (__in - N)"
    only_state.add_mapped_tasklet(
        "PreComp",
        map_ranges={"__i": "0:N"},
        inputs={"__in": dace.Memlet("A[__i]")},
        code=tasklet_code,
        outputs={"__out": dace.Memlet("B[__i]")},
        external_edges=True,
    )
    sdfg.validate()
    return sdfg


def make_single_state_with_two_maps_sdfg() -> dace.SDFG:
    """Build a one-state SDFG with two chained Maps: `A -> T -> B`.

    The first Map writes the transient `T`; the second Map reads `T` and writes
    `B`. All arrays have shape `(N,)`, and both tasklets depend on `N` and
    `lim_area`.
    """
    # Use a name distinct from `make_single_state_sdfg()` ("single_state_sdfg").
    #  Identically named SDFGs may collide, e.g. in the build cache, when both
    #  are compiled in the same session.
    sdfg = dace.SDFG("single_state_with_two_maps_sdfg")
    state = sdfg.add_state(is_start_block=True)
    sdfg.add_symbol("N", dace.int64)
    sdfg.add_symbol("lim_area", dace.bool_)
    for name in "AB":
        sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False)
    sdfg.add_array("T", shape=("N",), dtype=dace.float64, transient=True)

    # One shared AccessNode: the first Map writes into it and the second reads from it.
    tmp_node = state.add_access("T")

    state.add_mapped_tasklet(
        "comp1",
        map_ranges={"__i": "0:N"},
        inputs={"__in": dace.Memlet("A[__i]")},
        code="__out = (__in + N) if lim_area else (__in - N)",
        outputs={"__out": dace.Memlet("T[__i]")},
        output_nodes={tmp_node},
        external_edges=True,
    )
    state.add_mapped_tasklet(
        "comp2",
        map_ranges={"__i": "0:N"},
        inputs={"__in": dace.Memlet("T[__i]")},
        code="__out = (__in + 7 * N) if lim_area else (__in - 4 * N)",
        outputs={"__out": dace.Memlet("B[__i]")},
        input_nodes={tmp_node},
        external_edges=True,
    )
    sdfg.validate()
    return sdfg


def make_wrapped_sdfg(
        single_state: bool,
) -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
    """Wrap one of the test SDFGs inside a NestedSDFG node.

    Args:
        single_state: If `True`, the nested SDFG is `make_single_state_sdfg()`;
            otherwise it is `make_multi_state_sdfg()`.

    Returns:
        The outer SDFG and the NestedSDFG node inside it. The outer symbols `N`
        and `lim_area` are forwarded 1:1 to the nested SDFG. The outer arrays
        `A` and `B` are connected to the connectors of the same names.
    """
    sdfg = dace.SDFG("wrapped_sdfg")
    state = sdfg.add_state("wrap_state", is_start_block=True)
    sdfg.add_symbol("lim_area", dace.bool_)
    # `N` is used as an array size and is declared `int64` in the nested SDFGs;
    #  declaring it as `bool_` here made the outer and inner symbol types disagree.
    sdfg.add_symbol("N", dace.int64)
    for name in "AB":
        sdfg.add_array(name, shape=("N",), dtype=dace.float64, transient=False)

    inner_sdfg = make_single_state_sdfg() if single_state else make_multi_state_sdfg()
    nsdfg = state.add_nested_sdfg(
        sdfg=inner_sdfg,
        parent=sdfg,
        inputs={"A"},
        outputs={"B"},
        symbol_mapping={"lim_area": "lim_area", "N": "N"},
    )
    state.add_edge(
        state.add_access("A"),
        None,
        nsdfg,
        "A",
        dace.Memlet.from_array("A", sdfg.arrays["A"])
    )
    state.add_edge(
        nsdfg,
        "B",
        state.add_access("B"),
        None,
        dace.Memlet.from_array("B", sdfg.arrays["B"])
    )
    sdfg.validate()
    return sdfg, nsdfg


def gt_substitute_compiletime_symbols(
    sdfg: dace.SDFG,
    repl: dict[str, Any],
    validate: bool = False,
    validate_all: bool = False,
) -> None:
    """Replace the symbols listed in `repl` by their values, in place.

    The replacement is done by DaCe's `ConstantPropagation` pass. The pass runs
    with `recursive` enabled and is seeded with `repl` as its initial symbols.

    Args:
        sdfg: The SDFG to process; it is modified in place.
        repl: Maps symbol names to the compile-time values they are replaced with.
        validate: Validate `sdfg` after the replacement.
        validate_all: Validate `sdfg` after every step. There is only a single
            step, so this has the same effect as `validate`.
    """
    const_prop = dace_passes.ConstantPropagation()
    const_prop.recursive = True
    const_prop.progress = False

    # `apply_pass()` takes a second parameter named `_` that this pass does not
    #  use (presumably the results of earlier pipeline passes; TODO confirm).
    const_prop.apply_pass(
        sdfg=sdfg,
        initial_symbols=repl,
        _=None,
    )
    if validate or validate_all:
        sdfg.validate()


def test_nested_sdfg_with_single_state():
    """`N` and `lim_area` must be fully substituted into a single-state NestedSDFG."""
    outer_sdfg, nsdfg_node = make_wrapped_sdfg(single_state=True)

    # Before the substitution, everything still refers to the symbols.
    assert check_shapes(outer_sdfg, ("N",))
    assert check_shapes(nsdfg_node.sdfg, ("N",))
    assert check_maps(nsdfg_node.sdfg, "N")
    assert check_tasklets(nsdfg_node.sdfg, expected_symbols={"N", "lim_area"})

    gt_substitute_compiletime_symbols(outer_sdfg, {"N": 10, "lim_area": True})

    # Afterwards the outer shapes compare equal to `10`. The nested shapes, Map
    #  ranges and tasklets no longer mention the symbols.
    assert check_shapes(outer_sdfg, (10,), to_string=False)
    assert check_shapes(nsdfg_node.sdfg, ("10",))
    assert check_maps(nsdfg_node.sdfg, "10")
    assert check_tasklets(nsdfg_node.sdfg, forbidden_symbols={"N", "lim_area"})
    # No symbol is forwarded into the nested SDFG anymore.
    assert not nsdfg_node.symbol_mapping


def test_nested_sdfg_with_multiple_states():
    """`N` and `lim_area` must be fully substituted into a multi-state NestedSDFG.

    This is the case reported as failing: only `lim_area` gets replaced.
    """
    outer_sdfg, nsdfg_node = make_wrapped_sdfg(single_state=False)

    # Before the substitution, everything still refers to the symbols.
    assert check_shapes(outer_sdfg, ("N",))
    assert check_shapes(nsdfg_node.sdfg, ("N",))
    assert check_maps(nsdfg_node.sdfg, "N")
    assert check_tasklets(nsdfg_node.sdfg, expected_symbols={"N", "lim_area"})

    gt_substitute_compiletime_symbols(outer_sdfg, {"N": 10, "lim_area": True})

    # After the substitution, the symbols must be gone everywhere.
    assert check_shapes(outer_sdfg, (10,), to_string=False)
    assert check_shapes(nsdfg_node.sdfg, ("10",))
    assert check_maps(nsdfg_node.sdfg, "10")
    assert check_tasklets(nsdfg_node.sdfg, forbidden_symbols={"N", "lim_area"})
    # No symbol is forwarded into the nested SDFG anymore.
    assert not nsdfg_node.symbol_mapping


def test_single_state_top_sdfg():
    """Substitution on a top-level SDFG with a single state and a single Map."""
    # This case works because everything is inside a single state.
    sdfg = make_single_state_sdfg()
    assert sdfg.number_of_nodes() == 1

    symbols = {"N", "lim_area"}
    assert check_maps(sdfg, "N")
    assert check_shapes(sdfg, ("N",))
    assert check_tasklets(sdfg, expected_symbols=symbols)

    gt_substitute_compiletime_symbols(sdfg, {"N": 10, "lim_area": True})

    assert check_maps(sdfg, "10")
    assert check_shapes(sdfg, (10,), to_string=False)
    assert check_tasklets(sdfg, forbidden_symbols=symbols)


def test_single_state_with_two_maps():
    """Substitution on a top-level single-state SDFG with two chained Maps."""
    # This case works because everything is inside a single state.
    sdfg = make_single_state_with_two_maps_sdfg()
    assert sdfg.number_of_nodes() == 1

    symbols = {"N", "lim_area"}
    assert check_maps(sdfg, "N")
    assert check_shapes(sdfg, ("N",))
    assert check_tasklets(sdfg, expected_symbols=symbols)

    gt_substitute_compiletime_symbols(sdfg, {"N": 10, "lim_area": True})

    assert check_maps(sdfg, "10")
    assert check_shapes(sdfg, (10,), to_string=False)
    assert check_tasklets(sdfg, forbidden_symbols=symbols)


def test_multi_state_top_sdfg():
    """Substitution on a top-level SDFG whose states form an if/else on `lim_area`.

    Both the interstate edge conditions and the tasklets depend on the symbols.
    After the substitution, none of them may refer to `N` or `lim_area`.
    """
    sdfg = make_multi_state_sdfg()
    assert sdfg.number_of_nodes() == 4

    start_state: dace.SDFGState = sdfg.start_state
    assert start_state.label == "stateS"
    assert all("lim_area" in edge.data.free_symbols for edge in sdfg.out_edges(start_state))

    assert check_maps(sdfg, "N")
    assert check_shapes(sdfg, ("N",))
    assert check_tasklets(sdfg, expected_symbols={"N", "lim_area"})

    repl = {"N": 10, "lim_area": True}
    gt_substitute_compiletime_symbols(sdfg, repl)

    assert check_maps(sdfg, "10")
    assert check_shapes(sdfg, (10,), to_string=False)
    # The edges forming the state-level `if` shall no longer depend on `lim_area`,
    #  as it was replaced.
    assert not any("lim_area" in edge.data.free_symbols for edge in sdfg.out_edges(start_state))
    # The tasklets shall no longer depend on `N` or `lim_area`. `check_tasklets()`
    #  collects the tasklets again instead of reusing a list gathered before the
    #  pass, so tasklets that the pass replaced or added are checked as well.
    assert check_tasklets(sdfg, forbidden_symbols={"N", "lim_area"})


def test_single_stet_nested_with_top_map():
    """Substitution when both the outer SDFG and its single-state NestedSDFG contain Maps.

    A Map is added in front of the NestedSDFG. It reads the new global array
    `new_input` and writes `A`, which is turned into a transient that feeds
    the NestedSDFG.
    """
    sdfg, nested_sdfg = make_wrapped_sdfg(single_state=True)
    assert sdfg.number_of_nodes() == 1
    state: dace.SDFGState = next(iter(sdfg.states()))

    sdfg.add_datadesc("new_input", sdfg.arrays["A"].clone())
    sdfg.arrays["A"].transient = True
    access_a: dace_nodes.AccessNode = next(
        dnode for dnode in state.data_nodes() if dnode.data == "A"
    )
    state.add_mapped_tasklet(
        "compOutside",
        map_ranges={"__i": "0:N"},
        inputs={"__in": dace.Memlet("new_input[__i]")},
        code="__out = (__in + 10 * N) if lim_area else (__in - 14 * N)",
        outputs={"__out": dace.Memlet("A[__i]")},
        output_nodes={access_a},
        external_edges=True,
    )
    sdfg.validate()

    symbols = {"N", "lim_area"}
    inner = nested_sdfg.sdfg
    assert check_maps(sdfg, "N")
    assert check_shapes(sdfg, ("N",))
    assert check_tasklets(sdfg, expected_symbols=symbols)
    assert check_shapes(inner, ("N",))
    assert check_maps(inner, "N")
    assert check_tasklets(inner, expected_symbols=symbols)

    gt_substitute_compiletime_symbols(sdfg, {"N": 10, "lim_area": True})

    inner = nested_sdfg.sdfg
    assert check_maps(sdfg, "10")
    assert check_shapes(sdfg, (10,), to_string=False)
    assert check_tasklets(sdfg, forbidden_symbols=symbols)
    assert check_shapes(inner, ("10",))
    assert check_maps(inner, "10")
    assert check_tasklets(inner, forbidden_symbols=symbols)
    assert not nested_sdfg.symbol_mapping
@philip-paul-mueller
Copy link
Collaborator Author

I updated the reproducer; it can now also be used as unit tests.

I extended the test, and it seems that it fails as soon as there is an SDFG with multiple states; however, some symbols are still correctly replaced.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant