From 6d331181598874c42aa731d871a12c6c65d53356 Mon Sep 17 00:00:00 2001
From: cpelley <carwyn.pelley@metoffice.gov.uk>
Date: Fri, 8 Nov 2024 10:51:21 +0000
Subject: [PATCH] MAINT: refactored events

---
 dagrunner/execute_graph.py  | 47 ++++++++++++++++---------------------
 dagrunner/utils/networkx.py | 47 +++++++++++++++++++++++++++++++------
 2 files changed, 60 insertions(+), 34 deletions(-)

diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py
index 48ee432..79b1dcb 100755
--- a/dagrunner/execute_graph.py
+++ b/dagrunner/execute_graph.py
@@ -19,55 +19,49 @@
 from dagrunner.runner.schedulers import SCHEDULERS
 from dagrunner.utils import (
     CaptureProcMemory,
+    Singleton,
     TimeIt,
+    as_iterable,
     function_to_argparse_parse_args,
     logger,
-    as_iterable,
-    Singleton
 )
 from dagrunner.utils.networkx import visualise_graph
 
 
-class _SKIP_EVENT(Singleton):
-    """
-    A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution
-    of all descendant node execution.
-    """
-
-    _instance = None
-
+class _EventBase:
     def __repr__(self):
-        return "SKIP_EVENT"
+        # Ensures easy identification when printing/logging.
+        return self.__class__.__name__.upper()
 
     def __hash__(self):
-        return hash("SKIP_EVENT")
+        # Ensures that can be used as keys in dictionaries or stored as sets.
+        return hash(self.__class__.__name__.upper())
 
     def __reduce__(self):
+        # Ensures that can be serialised and deserialised using pickle.
         return (self.__class__, ())
 
 
-SKIP_EVENT = _SKIP_EVENT()
-
-
-class _IGNORE_EVENT(Singleton):
+class _SkipEvent(_EventBase, metaclass=Singleton):
     """
-    A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by
-    `plugin_executor` in descendant node execution.
+    A plugin that returns a 'SKIP_EVENT' will cause `plugin_executor` to skip execution
+    of all descendant node execution.
     """
+    pass
 
-    _instance = None
 
-    def __repr__(self):
-        return "IGNORE_EVENT"
+SKIP_EVENT = _SkipEvent()
 
-    def __hash__(self):
-        return hash("IGNORE_EVENT")
 
-    def __reduce__(self):
-        return (self.__class__, ())
+class _IgnoreEvent(_EventBase, metaclass=Singleton):
+    """
+    A plugin that returns an 'IGNORE_EVENT' will be filtered out as arguments by
+    `plugin_executor` in descendant node execution.
+    """
+    pass
 
 
-IGNORE_EVENT = _IGNORE_EVENT()
+IGNORE_EVENT = _IgnoreEvent()
 
 
 class SkipBranch(Exception):
@@ -415,7 +409,6 @@ def _process_graph(self):
         if callable(self._nxgraph):
             self._nxgraph = self._nxgraph(**self._nxgraph_kwargs)
 
-
         if CONFIG["dagrunner_visualisation"].pop("enabled", False) is True:
             self.visualise(**CONFIG["dagrunner_visualisation"])
 
diff --git a/dagrunner/utils/networkx.py b/dagrunner/utils/networkx.py
index f735f28..5eefa68 100644
--- a/dagrunner/utils/networkx.py
+++ b/dagrunner/utils/networkx.py
@@ -233,6 +233,7 @@ def visualise_graph_mermaid(
     - `title`: The title of the visualisation.
     - `output_filepath`: The output filepath to save the visualisation to.
     """
+
     def gen_label(node_id, node, label_by):
         label = f"{node_id}"
         if label_by:
@@ -242,13 +243,24 @@ def gen_label(node_id, node, label_by):
             label += f"\n{str(node)}"
         return label
 
-    def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by):
+    def add_node(
+        node,
+        mermaid,
+        table,
+        node_id,
+        node_target_id_map,
+        node_info_lookup,
+        label_by,
+        group_by,
+    ):
         if node not in node_target_id_map:
             node_target_id_map[node] = node_id
             label = gen_label(node_id, node, label_by)
             tooltip = pprint.pformat(node_info_lookup[node])
 
-            subgraphs = [getattr(node, key) for key in group_by if getattr(node, key, None)]
+            subgraphs = [
+                getattr(node, key) for key in group_by if getattr(node, key, None)
+            ]
             for subgraph in subgraphs:
                 mermaid.add_raw(f"subgraph {subgraph}")
             mermaid.add_node(
@@ -262,7 +274,6 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup
             node_id += 1
         return node_id
 
-
     mermaid = MermaidGraph(title=title or "")
     table = HTMLTable(["id", "node", "info"])
 
@@ -272,10 +283,28 @@ def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup
     node_target_id_map = {}
     node_id = 0
     for target in graph.nodes:
-        node_id = add_node(target, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)
+        node_id = add_node(
+            target,
+            mermaid,
+            table,
+            node_id,
+            node_target_id_map,
+            node_info_lookup,
+            label_by,
+            group_by,
+        )
 
         for pred in graph.predecessors(target):
-            node_id = add_node(pred, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)
+            node_id = add_node(
+                pred,
+                mermaid,
+                table,
+                node_id,
+                node_target_id_map,
+                node_info_lookup,
+                label_by,
+                group_by,
+            )
             mermaid.add_connection(node_target_id_map[pred], node_target_id_map[target])
 
     if output_filepath:
@@ -361,8 +390,12 @@ def visualise_graph(
             node_info_lookup[node] = {"node data": data}
 
     if backend == "mermaid":
-        visualise_graph_mermaid(graph, node_info_lookup, title, output_filepath, **kwargs)
+        visualise_graph_mermaid(
+            graph, node_info_lookup, title, output_filepath, **kwargs
+        )
     elif backend == "matplotlib":
-        visualise_graph_matplotlib(graph, node_info_lookup, title, output_filepath, **kwargs)
+        visualise_graph_matplotlib(
+            graph, node_info_lookup, title, output_filepath, **kwargs
+        )
     else:
         raise ValueError(f"Unsupported visualisation backend: {backend}")