From 6d331181598874c42aa731d871a12c6c65d53356 Mon Sep 17 00:00:00 2001 From: cpelley <carwyn.pelley@metoffice.gov.uk> Date: Fri, 8 Nov 2024 10:51:21 +0000 Subject: [PATCH] MAINT: refactored events --- dagrunner/execute_graph.py | 47 ++++++++++++++++--------------------- dagrunner/utils/networkx.py | 47 +++++++++++++++++++++++++++++++------ 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py index 48ee432..79b1dcb 100755 --- a/dagrunner/execute_graph.py +++ b/dagrunner/execute_graph.py @@ -19,55 +19,49 @@ from dagrunner.runner.schedulers import SCHEDULERS from dagrunner.utils import ( CaptureProcMemory, + Singleton, TimeIt, + as_iterable, function_to_argparse_parse_args, logger, - as_iterable, - Singleton ) from dagrunner.utils.networkx import visualise_graph -class _SKIP_EVENT(Singleton): - """ - A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution - of all descendant node execution. - """ - - _instance = None - +class _EventBase: def __repr__(self): - return "SKIP_EVENT" + # Ensures easy identification when printing/logging. + return self.__class__.__name__.upper() def __hash__(self): - return hash("SKIP_EVENT") + # Ensures that can be used as keys in dictionaries or stored as sets. + return hash(self.__class__.__name__.upper()) def __reduce__(self): + # Ensures that can be serialised and deserialised using pickle. return (self.__class__, ()) -SKIP_EVENT = _SKIP_EVENT() - - -class _IGNORE_EVENT(Singleton): +class _SkipEvent(_EventBase, metaclass=Singleton): """ - A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by - `plugin_executor` in descendant node execution. + A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution + of all descendant node execution. """ + pass - _instance = None - def __repr__(self): - return "IGNORE_EVENT" +SKIP_EVENT = _SkipEvent() - def __hash__(self): - return hash("IGNORE_EVENT") - def __reduce__(self): - return (self.__class__, ()) +class _IgnoreEvent(_EventBase, metaclass=Singleton): + """ + A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by + `plugin_executor` in descendant node execution. + """ + pass -IGNORE_EVENT = _IGNORE_EVENT() +IGNORE_EVENT = _IgnoreEvent() class SkipBranch(Exception): @@ -415,7 +409,6 @@ def _process_graph(self): if callable(self._nxgraph): self._nxgraph = self._nxgraph(**self._nxgraph_kwargs) - if CONFIG["dagrunner_visualisation"].pop("enabled", False) is True: self.visualise(**CONFIG["dagrunner_visualisation"]) diff --git a/dagrunner/utils/networkx.py b/dagrunner/utils/networkx.py index f735f28..5eefa68 100644 --- a/dagrunner/utils/networkx.py +++ b/dagrunner/utils/networkx.py @@ -233,6 +233,7 @@ def visualise_graph_mermaid( - `title`: The title of the visualisation. - `output_filepath`: The output filepath to save the visualisation to. """ + def gen_label(node_id, node, label_by): label = f"{node_id}" if label_by: @@ -242,13 +243,24 @@ def gen_label(node_id, node, label_by): label += f"\n{str(node)}" return label - def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by): + def add_node( + node, + mermaid, + table, + node_id, + node_target_id_map, + node_info_lookup, + label_by, + group_by, + ): if node not in node_target_id_map: node_target_id_map[node] = node_id label = gen_label(node_id, node, label_by) tooltip = pprint.pformat(node_info_lookup[node]) - subgraphs = [getattr(node, key) for key in group_by if getattr(node, key, None)] + subgraphs = [ + getattr(node, key) for key in group_by if getattr(node, key, None) + ] for subgraph in subgraphs: mermaid.add_raw(f"subgraph {subgraph}") mermaid.add_node( @@ -262,7 +274,6 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup node_id += 1 return node_id - mermaid = MermaidGraph(title=title or "") table = HTMLTable(["id", "node", "info"]) @@ -272,10 +283,28 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup node_target_id_map = {} node_id = 0 for target in graph.nodes: - node_id = add_node(target, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by) + node_id = add_node( + target, + mermaid, + table, + node_id, + node_target_id_map, + node_info_lookup, + label_by, + group_by, + ) for pred in graph.predecessors(target): - node_id = add_node(pred, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by) + node_id = add_node( + pred, + mermaid, + table, + node_id, + node_target_id_map, + node_info_lookup, + label_by, + group_by, + ) mermaid.add_connection(node_target_id_map[pred], node_target_id_map[target]) if output_filepath: @@ -361,8 +390,12 @@ def visualise_graph( node_info_lookup[node] = {"node data": data} if backend == "mermaid": - visualise_graph_mermaid(graph, node_info_lookup, title, output_filepath, **kwargs) + visualise_graph_mermaid( + graph, node_info_lookup, title, output_filepath, **kwargs + ) elif backend == "matplotlib": - visualise_graph_matplotlib(graph, node_info_lookup, title, output_filepath, **kwargs) + visualise_graph_matplotlib( + graph, node_info_lookup, title, output_filepath, **kwargs + ) else: raise ValueError(f"Unsupported visualisation backend: {backend}")