Skip to content

Commit

Permalink
new approach for tags
Browse files Browse the repository at this point in the history
Signed-off-by: Jitendra Gundaniya <[email protected]>
  • Loading branch information
jitu5 committed May 1, 2024
1 parent 7a205b0 commit c8f4983
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 47 deletions.
41 changes: 20 additions & 21 deletions package/kedro_viz/data_access/repositories/modular_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""`kedro_viz.data_access.repositories.modular_pipelines`
defines repository to centralise access to modular pipelines data."""

from typing import Dict, Optional, Set, Union
from typing import Dict, Optional, Union

from kedro_viz.constants import ROOT_MODULAR_PIPELINE_ID
from kedro_viz.models.flowchart import (
Expand Down Expand Up @@ -163,9 +163,25 @@ def add_output(self, modular_pipeline_id: str, output_node: GraphNode):
self.tree[modular_pipeline_id].external_outputs.add(output_node.id)

def add_tags(self, modular_pipeline_id: str, node_tags: set):
    """Add tags to a modular pipeline.

    The tags are merged in place into the modular pipeline's existing tags.
    If ``modular_pipeline_id`` is not present in ``self.tree``, nothing is
    changed and no modular pipeline is created.

    Args:
        modular_pipeline_id: ID of the modular pipeline to add the tags to.
        node_tags: The tags to add to the modular pipeline.

    Example:
        >>> modular_pipelines = ModularPipelinesRepository()
        >>> # The modular pipeline must exist before tags can be added to it.
        >>> _ = modular_pipelines.get_or_create_modular_pipeline("data_science")
        >>> node_tags = {"tag1", "tag2"}
        >>> modular_pipelines.add_tags("data_science", node_tags)
        >>> data_science_pipeline = modular_pipelines.get_or_create_modular_pipeline(
        ...     "data_science"
        ... )
        >>> assert "tag1" in data_science_pipeline.tags
        >>> assert "tag2" in data_science_pipeline.tags
    """
    # Unknown IDs are skipped deliberately: this method only annotates
    # modular pipelines that are already registered in the tree.
    if modular_pipeline_id in self.tree:
        self.tree[modular_pipeline_id].tags |= node_tags

def add_child(self, modular_pipeline_id: str, child: ModularPipelineChild):
"""Add a child to a modular pipeline.
Expand Down Expand Up @@ -215,8 +231,6 @@ def extract_from_node(self, node: GraphNode) -> Optional[str]:
return None

modular_pipeline = self.get_or_create_modular_pipeline(modular_pipeline_id)
# Inherit tags from the nodes of the modular pipeline.
self.inherit_tags_recursive(modular_pipeline_id, node.tags)

# Add the node's registered pipelines to the modular pipeline's registered pipelines.
# Basically this means if the node belongs to the "__default__" pipeline, for example,
Expand All @@ -231,21 +245,6 @@ def extract_from_node(self, node: GraphNode) -> Optional[str]:
)
return modular_pipeline_id

def inherit_tags_recursive(self, modular_pipeline_id: str, tags: Set[str]):
    """Add tags to a modular pipeline and push its tags down to nested ones.

    The given tags are merged into the modular pipeline identified by
    ``modular_pipeline_id``. The pipeline's resulting tag set is then passed
    on, recursively, to each of its children whose type is
    ``GraphNodeType.MODULAR_PIPELINE``. An ID that is not in the tree is
    ignored.

    Args:
        modular_pipeline_id: ID of the modular pipeline to look up in the repository.
        tags: A set of tags to be added to the modular pipeline and its children.
    """
    pipeline_node = self.tree.get(modular_pipeline_id)
    if not pipeline_node:
        return

    pipeline_node.tags.update(tags)

    # Only nested modular pipelines inherit tags; other child types are skipped.
    nested_pipelines = (
        child
        for child in pipeline_node.children
        if child.type == GraphNodeType.MODULAR_PIPELINE
    )
    for nested in nested_pipelines:
        self.inherit_tags_recursive(nested.id, pipeline_node.tags)

def has_modular_pipeline(self, modular_pipeline_id: str) -> bool:
"""Return whether this modular pipeline repository has a given modular pipeline ID.
Args:
Expand Down
1 change: 0 additions & 1 deletion package/kedro_viz/services/modular_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,4 @@ def expand_tree(
expanded_tree[parent_id].external_outputs.update(
modular_pipeline_node.external_outputs
)
expanded_tree[parent_id].tags.update(modular_pipeline_node.tags)
return expanded_tree
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,6 @@ def test_extract_from_node(self, identity):
modular_pipelines.extract_from_node(task_node)
assert modular_pipelines.has_modular_pipeline("data_science")

def test_tags_inheritance(self, identity):
    """A modular pipeline receives the tags of a task node in its namespace."""
    namespaced_node = node(
        identity,
        inputs="x",
        outputs=None,
        namespace="parent",
        tags={"tag1", "tag2"},
    )
    task_node = GraphNode.create_task_node(namespaced_node)

    repository = ModularPipelinesRepository()
    # Register a nested modular pipeline under "parent" before extraction.
    nested_child = ModularPipelineChild(
        id="parent.child", type=GraphNodeType.MODULAR_PIPELINE
    )
    repository.add_child("parent", nested_child)

    repository.extract_from_node(task_node)
    parent_pipeline = repository.get_or_create_modular_pipeline("parent")

    # Only the parent's tags are checked here.
    for expected_tag in ("tag1", "tag2"):
        assert expected_tag in parent_pipeline.tags

def test_add_input(self):
kedro_dataset = CSVDataset(filepath="foo.csv")
modular_pipelines = ModularPipelinesRepository()
Expand Down Expand Up @@ -108,3 +83,14 @@ def test_add_output_should_raise_if_adding_non_data_node(self, identity):
modular_pipelines = ModularPipelinesRepository()
with pytest.raises(ValueError):
modular_pipelines.add_output("data_science", task_node)

def test_add_tags(self):
    """Tags added to an existing modular pipeline show up on that pipeline."""
    repository = ModularPipelinesRepository()
    node_tags = {"tag1", "tag2"}

    # The pipeline has to exist first; it is created here before tagging.
    repository.get_or_create_modular_pipeline("data_science")
    repository.add_tags("data_science", node_tags)

    tagged_pipeline = repository.get_or_create_modular_pipeline("data_science")
    for expected_tag in ("tag1", "tag2"):
        assert expected_tag in tagged_pipeline.tags

0 comments on commit c8f4983

Please sign in to comment.