Skip to content

Commit

Permalink
MAINT: refactored events
Browse files Browse the repository at this point in the history
  • Loading branch information
cpelley committed Nov 8, 2024
1 parent db89c57 commit 6d33118
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
47 changes: 20 additions & 27 deletions dagrunner/execute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,55 +19,49 @@
from dagrunner.runner.schedulers import SCHEDULERS
from dagrunner.utils import (
CaptureProcMemory,
Singleton,
TimeIt,
as_iterable,
function_to_argparse_parse_args,
logger,
as_iterable,
Singleton
)
from dagrunner.utils.networkx import visualise_graph


class _SKIP_EVENT(Singleton):
"""
A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution
of all descendant node execution.
"""

_instance = None

class _EventBase:
def __repr__(self):
return "SKIP_EVENT"
# Ensures easy identification when printing/logging.
return self.__class__.__name__.upper()

def __hash__(self):
return hash("SKIP_EVENT")
# Ensures that can be used as keys in dictionaries or stored as sets.
return hash(self.__class__.__name__.upper())

def __reduce__(self):
# Ensures that can be serialised and deserialised using pickle.
return (self.__class__, ())


SKIP_EVENT = _SKIP_EVENT()


class _IGNORE_EVENT(Singleton):
class _SkipEvent(_EventBase, metaclass=Singleton):
"""
A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by
`plugin_executor` in descendant node execution.
A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution
of all descendant node execution.
"""
pass

_instance = None

def __repr__(self):
return "IGNORE_EVENT"
SKIP_EVENT = _SkipEvent()

def __hash__(self):
return hash("IGNORE_EVENT")

def __reduce__(self):
return (self.__class__, ())
class _IgnoreEvent(_EventBase, metaclass=Singleton):
"""
A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by
`plugin_executor` in descendant node execution.
"""
pass


IGNORE_EVENT = _IGNORE_EVENT()
IGNORE_EVENT = _IgnoreEvent()


class SkipBranch(Exception):
Expand Down Expand Up @@ -415,7 +409,6 @@ def _process_graph(self):
if callable(self._nxgraph):
self._nxgraph = self._nxgraph(**self._nxgraph_kwargs)


if CONFIG["dagrunner_visualisation"].pop("enabled", False) is True:
self.visualise(**CONFIG["dagrunner_visualisation"])

Expand Down
47 changes: 40 additions & 7 deletions dagrunner/utils/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def visualise_graph_mermaid(
- `title`: The title of the visualisation.
- `output_filepath`: The output filepath to save the visualisation to.
"""

def gen_label(node_id, node, label_by):
label = f"{node_id}"
if label_by:
Expand All @@ -242,13 +243,24 @@ def gen_label(node_id, node, label_by):
label += f"\n{str(node)}"
return label

def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by):
def add_node(
node,
mermaid,
table,
node_id,
node_target_id_map,
node_info_lookup,
label_by,
group_by,
):
if node not in node_target_id_map:
node_target_id_map[node] = node_id
label = gen_label(node_id, node, label_by)
tooltip = pprint.pformat(node_info_lookup[node])

subgraphs = [getattr(node, key) for key in group_by if getattr(node, key, None)]
subgraphs = [
getattr(node, key) for key in group_by if getattr(node, key, None)
]
for subgraph in subgraphs:
mermaid.add_raw(f"subgraph {subgraph}")
mermaid.add_node(
Expand All @@ -262,7 +274,6 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup
node_id += 1
return node_id


mermaid = MermaidGraph(title=title or "")
table = HTMLTable(["id", "node", "info"])

Expand All @@ -272,10 +283,28 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup
node_target_id_map = {}
node_id = 0
for target in graph.nodes:
node_id = add_node(target, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)
node_id = add_node(
target,
mermaid,
table,
node_id,
node_target_id_map,
node_info_lookup,
label_by,
group_by,
)

for pred in graph.predecessors(target):
node_id = add_node(pred, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)
node_id = add_node(
pred,
mermaid,
table,
node_id,
node_target_id_map,
node_info_lookup,
label_by,
group_by,
)
mermaid.add_connection(node_target_id_map[pred], node_target_id_map[target])

if output_filepath:
Expand Down Expand Up @@ -361,8 +390,12 @@ def visualise_graph(
node_info_lookup[node] = {"node data": data}

if backend == "mermaid":
visualise_graph_mermaid(graph, node_info_lookup, title, output_filepath, **kwargs)
visualise_graph_mermaid(
graph, node_info_lookup, title, output_filepath, **kwargs
)
elif backend == "matplotlib":
visualise_graph_matplotlib(graph, node_info_lookup, title, output_filepath, **kwargs)
visualise_graph_matplotlib(
graph, node_info_lookup, title, output_filepath, **kwargs
)
else:
raise ValueError(f"Unsupported visualisation backend: {backend}")

0 comments on commit 6d33118

Please sign in to comment.