From c7cdab95172fe6fe60c3e50b34389b81c46bd697 Mon Sep 17 00:00:00 2001 From: rashidakanchwala <37628668+rashidakanchwala@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:40:02 +0000 Subject: [PATCH] Fix `tag` being undefined bug from the backend. (#2162) Resolves #2106 --- .../data_access/repositories/graph.py | 9 ++-- package/tests/conftest.py | 53 +++++++++++++++++++ .../test_responses/test_pipelines.py | 14 +++++ 3 files changed, 72 insertions(+), 4 deletions(-) diff --git a/package/kedro_viz/data_access/repositories/graph.py b/package/kedro_viz/data_access/repositories/graph.py index bea6095bc..463012800 100644 --- a/package/kedro_viz/data_access/repositories/graph.py +++ b/package/kedro_viz/data_access/repositories/graph.py @@ -12,11 +12,12 @@ def __init__(self): self.nodes_dict: Dict[str, GraphNode] = {} self.nodes_list: List[GraphNode] = [] - def has_node(self, node: GraphNode) -> bool: - return node.id in self.nodes_dict - def add_node(self, node: GraphNode) -> GraphNode: - if not self.has_node(node): + existing_node = self.nodes_dict.get(node.id) + if existing_node: + # Update tags or other attributes if the node already exists + existing_node.tags.update(node.tags) + else: self.nodes_dict[node.id] = node self.nodes_list.append(node) return self.nodes_dict[node.id] diff --git a/package/tests/conftest.py b/package/tests/conftest.py index 5c1a300ab..ea25e94f7 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -222,6 +222,7 @@ def example_pipeline_with_node_namespaces(): inputs=["raw_transaction_data", "cleaned_transaction_data"], outputs="validated_transaction_data", name="validation_node", + tags=["validation"], ), node( func=lambda validated_data, enrichment_data: ( @@ -381,6 +382,23 @@ def edge_case_example_pipelines( } +@pytest.fixture +def example_pipelines_with_additional_tags(example_pipeline_with_node_namespaces): + """ + Fixture to mock the use cases mentioned in + https://github.com/kedro-org/kedro-viz/issues/2106 + """ + + pipelines_dict = { + "pipeline": example_pipeline_with_node_namespaces, + "pipeline_with_tags": pipeline( + example_pipeline_with_node_namespaces, tags=["tag1", "tag2"] + ), + } + + yield pipelines_dict + + @pytest.fixture def expected_modular_pipeline_tree_for_edge_cases(): expected_tree_for_edge_cases_file_path = ( @@ -554,6 +572,41 @@ def example_api_for_edge_case_pipelines( yield api +@pytest.fixture +def example_api_for_pipelines_with_additional_tags( + data_access_manager: DataAccessManager, + example_pipelines_with_additional_tags: Dict[str, Pipeline], + example_catalog: DataCatalog, + session_store: BaseSessionStore, + mocker, +): + api = apps.create_api_app_from_project(mock.MagicMock()) + + # For readability we are not hashing the node id + mocker.patch("kedro_viz.utils._hash", side_effect=lambda value: value) + mocker.patch( + "kedro_viz.data_access.repositories.modular_pipelines._hash", + side_effect=lambda value: value, + ) + + populate_data( + data_access_manager, + example_catalog, + example_pipelines_with_additional_tags, + session_store, + {}, + ) + mocker.patch( + "kedro_viz.api.rest.responses.pipelines.data_access_manager", + new=data_access_manager, + ) + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, + ) + yield api + + @pytest.fixture def example_transcoded_api( data_access_manager: DataAccessManager, diff --git a/package/tests/test_api/test_rest/test_responses/test_pipelines.py b/package/tests/test_api/test_rest/test_responses/test_pipelines.py index 4b933e33e..b1d14d8ca 100755 --- a/package/tests/test_api/test_rest/test_responses/test_pipelines.py +++ b/package/tests/test_api/test_rest/test_responses/test_pipelines.py @@ -35,6 +35,20 @@ def test_endpoint_main_no_default_pipeline(self, example_api_no_default_pipeline {"id": "data_processing", "name": "data_processing"}, ] + def test_endpoint_main_for_pipelines_with_additional_tags( + self, + example_api_for_pipelines_with_additional_tags, + ): + expected_tags = [ + {"id": "tag1", "name": "tag1"}, + {"id": "tag2", "name": "tag2"}, + {"id": "validation", "name": "validation"}, + ] + client = TestClient(example_api_for_pipelines_with_additional_tags) + response = client.get("/api/main") + actual_tags = response.json()["tags"] + assert actual_tags == expected_tags + def test_endpoint_main_for_edge_case_pipelines( self, example_api_for_edge_case_pipelines,