Skip to content

Commit

Permalink
label control and subgraph control
Browse files Browse the repository at this point in the history
  • Loading branch information
cpelley committed Nov 7, 2024
1 parent 10faa53 commit f14baa0
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 28 deletions.
4 changes: 4 additions & 0 deletions dagrunner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class GlobalConfiguration(object, metaclass=Singleton):
collapse_properties
backend
output_filepath
group_by
label_by
# Logging
[dagrunner_logging]
Expand All @@ -63,6 +65,8 @@ class GlobalConfiguration(object, metaclass=Singleton):
"collapse_properties": None,
"backend": None,
"output_filepath": None,
"group_by": None,
"label_by": None,
},
}

Expand Down
2 changes: 2 additions & 0 deletions dagrunner/dagrunner.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#enabled=True
#collapse_properties=leadtime
#backend=mermaid
#group_by=diagnostic
#label_by=step

#[dagrunner_graph_filter]
#diagnostic=pmsl
Expand Down
9 changes: 9 additions & 0 deletions dagrunner/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@
import socket
import threading
import time
from typing import Iterable
from abc import ABC, abstractmethod

import dagrunner.utils._doc_styles as doc_styles


def as_iterable(obj):
if not isinstance(obj, Iterable) or isinstance(
obj, (str, bytes)
):
obj = [obj]
return obj


class Singleton(type):
"""
Singleton metaclass.
Expand Down
73 changes: 45 additions & 28 deletions dagrunner/utils/networkx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import networkx as nx

from . import as_iterable
from .visualisation import HTMLTable, MermaidGraph, MermaidHTML


Expand Down Expand Up @@ -220,6 +221,8 @@ def visualise_graph_mermaid(
node_info_lookup: dict = None,
title: str = None,
output_filepath: str = None,
group_by: str = None,
label_by: Iterable = None,
):
"""
Visualise a networkx graph using matplotlib.
Expand All @@ -230,37 +233,53 @@ def visualise_graph_mermaid(
- `title`: The title of the visualisation.
- `output_filepath`: The output filepath to save the visualisation to.
"""
mermaid = MermaidGraph(title=title or "")
table = HTMLTable(["id", "node", "info"])

node_target_id_map = {}
node_id = 0
for target in graph.nodes:
if target not in node_target_id_map:
node_target_id_map[target] = node_id
label = f"{node_id}\n{str(target)}"
tooltip = pprint.pformat(node_info_lookup[target])
def gen_label(node_id, node, label_by):
label = f"{node_id}"
if label_by:
for key in label_by:
label += f"\n{key}: {getattr(node, key)}"
else:
label += f"\n{str(node)}"
return label

def add_node(node, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by):
if node not in node_target_id_map:
node_target_id_map[node] = node_id
label = gen_label(node_id, node, label_by)
tooltip = pprint.pformat(node_info_lookup[node])

subgraphs = [getattr(node, key) for key in group_by if hasattr(node, key)]
for subgraph in subgraphs:
mermaid.add_raw(f"subgraph {subgraph}")
mermaid.add_node(
node_id,
label=label,
tooltip=tooltip,
)
table.add_row(node_id, target, tooltip)
for subgraph in subgraphs:
mermaid.add_raw("end")
table.add_row(node_id, node, tooltip)
node_id += 1
return node_id


mermaid = MermaidGraph(title=title or "")
table = HTMLTable(["id", "node", "info"])

if label_by:
label_by = as_iterable(label_by)
if group_by:
group_by = as_iterable(group_by)

node_target_id_map = {}
node_id = 0
for target in graph.nodes:
node_id = add_node(target, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)

for pred in graph.predecessors(target):
if pred not in node_target_id_map:
node_target_id_map[pred] = node_id
label = f"{node_id}\n{str(pred)}"
tooltip = pprint.pformat(node_info_lookup[pred])
mermaid.add_node(
node_id,
label=label,
tooltip=tooltip,
)
table.add_row(node_id, pred, tooltip)
node_id += 1
node_id = add_node(pred, mermaid, table, node_id, node_target_id_map, node_info_lookup, label_by, group_by)
mermaid.add_connection(node_target_id_map[pred], node_target_id_map[target])

if output_filepath:
MermaidHTML(mermaid, table).save(output_filepath)
else:
Expand All @@ -277,6 +296,7 @@ def visualise_graph(
collapse_properties: Iterable = None,
title=None,
output_filepath=None,
**kwargs,
):
"""
Visualise a networkx graph.
Expand All @@ -297,10 +317,7 @@ def visualise_graph(
"""
node_info_lookup = {}
if collapse_properties:
if not isinstance(collapse_properties, Iterable) or isinstance(
collapse_properties, (str, bytes)
):
collapse_properties = [collapse_properties]
collapse_properties = as_iterable(collapse_properties)
if not dataclasses.is_dataclass(next(iter(graph.nodes))):
raise TypeError(
"Graph collapse along properties only supported for dataclasses right "
Expand Down Expand Up @@ -346,8 +363,8 @@ def visualise_graph(
node_info_lookup[node] = {"node data": data}

if backend == "mermaid":
visualise_graph_mermaid(graph, node_info_lookup, title, output_filepath)
visualise_graph_mermaid(graph, node_info_lookup, title, output_filepath, **kwargs)
elif backend == "matplotlib":
visualise_graph_matplotlib(graph, node_info_lookup, title, output_filepath)
visualise_graph_matplotlib(graph, node_info_lookup, title, output_filepath, **kwargs)
else:
raise ValueError(f"Unsupported visualisation backend: {backend}")
3 changes: 3 additions & 0 deletions dagrunner/utils/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self, title=None):
self._cont = ""
self._title = title or ""

def add_raw(self, raw):
self._cont += f"\n{raw}"

def add_node(self, nodeid, label=None, tooltip=None, url=None):
if label:
label = label.replace("\n", self.CARRIAGE_RETURN)
Expand Down

0 comments on commit f14baa0

Please sign in to comment.