From 4db2bf95a25f438deba5d6ba739e7b7acd0b45e2 Mon Sep 17 00:00:00 2001 From: Sajid Alam <90610031+SajidAlamQB@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:15:43 +0000 Subject: [PATCH 1/2] Breakdown flowchart models into separate files (#2144) * imitial Signed-off-by: Sajid Alam * update Signed-off-by: Sajid Alam * split into modular pipelines Signed-off-by: Sajid Alam * remove comment Signed-off-by: Sajid Alam * move GraphNodeType to nodes Signed-off-by: Sajid Alam * refactor Signed-off-by: Sajid Alam * fix refactors Signed-off-by: Sajid Alam * fix imports Signed-off-by: Sajid Alam * resolve circular dependency Signed-off-by: Sajid Alam * fix tests Signed-off-by: Sajid Alam * lint Signed-off-by: Sajid Alam * changes based on review Signed-off-by: Sajid Alam * split flowchart test file Signed-off-by: Sajid Alam * Update node_metadata.py Signed-off-by: Sajid Alam * Update ruff.toml Signed-off-by: Sajid Alam * lint Signed-off-by: Sajid Alam * move test files Signed-off-by: Sajid Alam * moved to named_entities.py Signed-off-by: Sajid Alam --------- Signed-off-by: Sajid Alam --- package/kedro_viz/api/rest/responses.py | 6 +- package/kedro_viz/data_access/managers.py | 8 +- .../data_access/repositories/graph.py | 3 +- .../repositories/modular_pipelines.py | 4 +- .../repositories/registered_pipelines.py | 2 +- .../data_access/repositories/tags.py | 2 +- .../kedro_viz/models/flowchart/__init__.py | 0 package/kedro_viz/models/flowchart/edge.py | 15 + .../kedro_viz/models/flowchart/model_utils.py | 45 ++ .../models/flowchart/named_entities.py | 41 ++ .../models/flowchart/node_metadata.py | 406 +++++++++++++ .../{flowchart.py => flowchart/nodes.py} | 539 ++---------------- package/kedro_viz/services/layers.py | 2 +- package/tests/conftest.py | 3 +- .../test_api/test_rest/test_responses.py | 2 +- .../tests/test_data_access/test_managers.py | 6 +- .../test_repositories/test_graph.py | 3 +- .../test_modular_pipelines.py | 7 +- .../test_models/test_flowchart/__init__.py | 0 .../test_node_metadata.py} | 268 +-------- .../test_models/test_flowchart/test_nodes.py | 248 ++++++++ .../test_flowchart/test_pipeline.py | 32 ++ package/tests/test_services/test_layers.py | 2 +- ruff.toml | 3 +- 24 files changed, 855 insertions(+), 792 deletions(-) create mode 100644 package/kedro_viz/models/flowchart/__init__.py create mode 100644 package/kedro_viz/models/flowchart/edge.py create mode 100644 package/kedro_viz/models/flowchart/model_utils.py create mode 100644 package/kedro_viz/models/flowchart/named_entities.py create mode 100644 package/kedro_viz/models/flowchart/node_metadata.py rename package/kedro_viz/models/{flowchart.py => flowchart/nodes.py} (53%) create mode 100644 package/tests/test_models/test_flowchart/__init__.py rename package/tests/test_models/{test_flowchart.py => test_flowchart/test_node_metadata.py} (55%) create mode 100644 package/tests/test_models/test_flowchart/test_nodes.py create mode 100644 package/tests/test_models/test_flowchart/test_pipeline.py diff --git a/package/kedro_viz/api/rest/responses.py b/package/kedro_viz/api/rest/responses.py index 5a38ef6b4c..1e885eced1 100644 --- a/package/kedro_viz/api/rest/responses.py +++ b/package/kedro_viz/api/rest/responses.py @@ -12,15 +12,13 @@ from kedro_viz.api.rest.utils import get_package_compatibilities from kedro_viz.data_access import data_access_manager -from kedro_viz.models.flowchart import ( - DataNode, +from kedro_viz.models.flowchart.node_metadata import ( DataNodeMetadata, ParametersNodeMetadata, - TaskNode, TaskNodeMetadata, - TranscodedDataNode, TranscodedDataNodeMetadata, ) +from kedro_viz.models.flowchart.nodes import DataNode, TaskNode, TranscodedDataNode from kedro_viz.models.metadata import Metadata, PackageCompatibility logger = logging.getLogger(__name__) diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py index 40e8ac56f6..4468804c77 100644 --- a/package/kedro_viz/data_access/managers.py +++ b/package/kedro_viz/data_access/managers.py @@ -20,15 +20,15 @@ from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID from kedro_viz.integrations.utils import UnavailableDataset -from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline +from kedro_viz.models.flowchart.nodes import ( DataNode, - GraphEdge, GraphNode, - GraphNodeType, ModularPipelineChild, ModularPipelineNode, ParametersNode, - RegisteredPipeline, TaskNode, TranscodedDataNode, ) diff --git a/package/kedro_viz/data_access/repositories/graph.py b/package/kedro_viz/data_access/repositories/graph.py index 601e52d060..bea6095bc9 100644 --- a/package/kedro_viz/data_access/repositories/graph.py +++ b/package/kedro_viz/data_access/repositories/graph.py @@ -3,7 +3,8 @@ from typing import Dict, Generator, List, Optional, Set -from kedro_viz.models.flowchart import GraphEdge, GraphNode +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.nodes import GraphNode class GraphNodesRepository: diff --git a/package/kedro_viz/data_access/repositories/modular_pipelines.py b/package/kedro_viz/data_access/repositories/modular_pipelines.py index 746f6700df..dc51df7f80 100644 --- a/package/kedro_viz/data_access/repositories/modular_pipelines.py +++ b/package/kedro_viz/data_access/repositories/modular_pipelines.py @@ -8,9 +8,9 @@ from kedro.pipeline.node import Node as KedroNode from kedro_viz.constants import ROOT_MODULAR_PIPELINE_ID -from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.nodes import ( GraphNode, - GraphNodeType, ModularPipelineChild, ModularPipelineNode, ) diff --git a/package/kedro_viz/data_access/repositories/registered_pipelines.py b/package/kedro_viz/data_access/repositories/registered_pipelines.py index d73f621867..1309548fac 100644 --- a/package/kedro_viz/data_access/repositories/registered_pipelines.py +++ b/package/kedro_viz/data_access/repositories/registered_pipelines.py @@ -4,7 +4,7 @@ from collections import OrderedDict, defaultdict from typing import Dict, List, Optional, Set -from kedro_viz.models.flowchart import RegisteredPipeline +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline class RegisteredPipelinesRepository: diff --git a/package/kedro_viz/data_access/repositories/tags.py b/package/kedro_viz/data_access/repositories/tags.py index 0bb46949ac..a7bd33e31f 100644 --- a/package/kedro_viz/data_access/repositories/tags.py +++ b/package/kedro_viz/data_access/repositories/tags.py @@ -3,7 +3,7 @@ from typing import Iterable, List, Set -from kedro_viz.models.flowchart import Tag +from kedro_viz.models.flowchart.named_entities import Tag class TagsRepository: diff --git a/package/kedro_viz/models/flowchart/__init__.py b/package/kedro_viz/models/flowchart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/package/kedro_viz/models/flowchart/edge.py b/package/kedro_viz/models/flowchart/edge.py new file mode 100644 index 0000000000..439cafc782 --- /dev/null +++ b/package/kedro_viz/models/flowchart/edge.py @@ -0,0 +1,15 @@ +"""`kedro_viz.models.flowchart.edge` defines data models to represent Kedro edges in a viz graph.""" + +from pydantic import BaseModel + + +class GraphEdge(BaseModel, frozen=True): + """Represent an edge in the graph + + Args: + source (str): The id of the source node. + target (str): The id of the target node. + """ + + source: str + target: str diff --git a/package/kedro_viz/models/flowchart/model_utils.py b/package/kedro_viz/models/flowchart/model_utils.py new file mode 100644 index 0000000000..f12e94b669 --- /dev/null +++ b/package/kedro_viz/models/flowchart/model_utils.py @@ -0,0 +1,45 @@ +"""`kedro_viz.models.flowchart.model_utils` defines utils for Kedro entities in a viz graph.""" + +import logging +from enum import Enum +from types import FunctionType +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + + +def _parse_filepath(dataset_description: Dict[str, Any]) -> Optional[str]: + """ + Extract the file path from a dataset description dictionary. + """ + filepath = dataset_description.get("filepath") or dataset_description.get("path") + return str(filepath) if filepath else None + + +def _extract_wrapped_func(func: FunctionType) -> FunctionType: + """Extract a wrapped decorated function to inspect the source code if available. + Adapted from https://stackoverflow.com/a/43506509/1684058 + """ + if func.__closure__ is None: + return func + closure = (c.cell_contents for c in func.__closure__) + wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None) + # return the original function if it's not a decorated function + return func if wrapped_func is None else wrapped_func + + +# ============================================================================= +# Shared base classes and enumerations for model components +# ============================================================================= + + +class GraphNodeType(str, Enum): + """Represent all possible node types in the graph representation of a Kedro pipeline. + The type needs to inherit from str as well so FastAPI can serialise it. See: + https://fastapi.tiangolo.com/tutorial/path-params/#working-with-python-enumerations + """ + + TASK = "task" + DATA = "data" + PARAMETERS = "parameters" + MODULAR_PIPELINE = "modularPipeline" # CamelCase for frontend compatibility diff --git a/package/kedro_viz/models/flowchart/named_entities.py b/package/kedro_viz/models/flowchart/named_entities.py new file mode 100644 index 0000000000..65944c0764 --- /dev/null +++ b/package/kedro_viz/models/flowchart/named_entities.py @@ -0,0 +1,41 @@ +"""kedro_viz.models.flowchart.named_entities` defines data models for representing named entities +such as tags and registered pipelines within a Kedro visualization graph.""" + +from typing import Optional + +from pydantic import BaseModel, Field, ValidationInfo, field_validator + + +class NamedEntity(BaseModel): + """Represent a named entity (Tag/Registered Pipeline) in a Kedro project + Args: + id (str): Id of the registered pipeline + + Raises: + AssertionError: If id is not supplied during instantiation + """ + + id: str + name: Optional[str] = Field( + default=None, + validate_default=True, + description="The name of the entity", + ) + + @field_validator("name") + @classmethod + def set_name(cls, _, info: ValidationInfo): + """Ensures that the 'name' field is set to the value of 'id' if 'name' is not provided.""" + assert "id" in info.data + return info.data["id"] + + +class RegisteredPipeline(NamedEntity): + """Represent a registered pipeline in a Kedro project.""" + + +class Tag(NamedEntity): + """Represent a tag in a Kedro project.""" + + def __hash__(self) -> int: + return hash(self.id) diff --git a/package/kedro_viz/models/flowchart/node_metadata.py b/package/kedro_viz/models/flowchart/node_metadata.py new file mode 100644 index 0000000000..20940a9b3a --- /dev/null +++ b/package/kedro_viz/models/flowchart/node_metadata.py @@ -0,0 +1,406 @@ +""" +`kedro_viz.models.flowchart.node_metadata` defines data models to represent +Kedro metadata in a visualization graph. +""" + +import inspect +import logging +from abc import ABC +from pathlib import Path +from typing import ClassVar, Dict, List, Optional, Union, cast + +from kedro.pipeline.node import Node as KedroNode +from pydantic import BaseModel, Field, field_validator, model_validator + +try: + # kedro 0.18.12 onwards + from kedro.io.core import AbstractDataset +except ImportError: # pragma: no cover + # older versions + from kedro.io.core import AbstractDataSet as AbstractDataset # type: ignore + +from kedro_viz.models.utils import get_dataset_type + +from .model_utils import _extract_wrapped_func, _parse_filepath +from .nodes import DataNode, ParametersNode, TaskNode, TranscodedDataNode + +logger = logging.getLogger(__name__) + + +class GraphNodeMetadata(BaseModel, ABC): + """Represent a graph node's metadata.""" + + +class TaskNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a TaskNode. + + Args: + task_node (TaskNode): Task node to which this metadata belongs to. + + Raises: + AssertionError: If task_node is not supplied during instantiation. + """ + + task_node: TaskNode = Field(..., exclude=True) + + code: Optional[str] = Field( + default=None, + validate_default=True, + description="Source code of the node's function", + ) + + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="Path to the file where the node is defined", + ) + + parameters: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The parameters of the node, if available", + ) + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="The command to run the pipeline to this node", + ) + + inputs: Optional[List[str]] = Field( + default=None, validate_default=True, description="The inputs to the TaskNode" + ) + outputs: Optional[List[str]] = Field( + default=None, validate_default=True, description="The outputs from the TaskNode" + ) + + @model_validator(mode="before") + @classmethod + def check_task_node_exists(cls, values): + assert "task_node" in values + cls.set_task_and_kedro_node(values["task_node"]) + return values + + @classmethod + def set_task_and_kedro_node(cls, task_node): + cls.task_node = task_node + cls.kedro_node = cast(KedroNode, task_node.kedro_obj) + + @field_validator("code") + @classmethod + def set_code(cls, code): + # this is required to handle partial, curry functions + if inspect.isfunction(cls.kedro_node.func): + code = inspect.getsource(_extract_wrapped_func(cls.kedro_node.func)) + return code + + return None + + @field_validator("filepath") + @classmethod + def set_filepath(cls, filepath): + # this is required to handle partial, curry functions + if inspect.isfunction(cls.kedro_node.func): + code_full_path = ( + Path(inspect.getfile(cls.kedro_node.func)).expanduser().resolve() + ) + + try: + filepath = code_full_path.relative_to(Path.cwd().parent) + except ValueError: # pragma: no cover + # if the filepath can't be resolved relative to the current directory, + # e.g. either during tests or during launching development server + # outside of a Kedro project, simply return the fullpath to the file. + filepath = code_full_path + + return str(filepath) + + return None + + @field_validator("parameters") + @classmethod + def set_parameters(cls, _): + return cls.task_node.parameters + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + return f"kedro run --to-nodes='{cls.kedro_node.name}'" + + @field_validator("inputs") + @classmethod + def set_inputs(cls, _): + return cls.kedro_node.inputs + + @field_validator("outputs") + @classmethod + def set_outputs(cls, _): + return cls.kedro_node.outputs + + +class DataNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a DataNode. + + Args: + data_node (DataNode): Data node to which this metadata belongs to. + + Attributes: + is_all_previews_enabled (bool): Class-level attribute to determine if + previews are enabled for all nodes. This can be configured via CLI + or UI to manage the preview settings. + + Raises: + AssertionError: If data_node is not supplied during instantiation. + """ + + data_node: DataNode = Field(..., exclude=True) + + is_all_previews_enabled: ClassVar[bool] = True + + type: Optional[str] = Field( + default=None, validate_default=True, description="The type of the data node" + ) + + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="The path to the actual data file for the underlying dataset", + ) + + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="Command to run the pipeline to this node", + ) + + preview: Optional[Union[Dict, str]] = Field( + default=None, + validate_default=True, + description="Preview data for the underlying datanode", + ) + + preview_type: Optional[str] = Field( + default=None, + validate_default=True, + description="Type of preview for the dataset", + ) + + stats: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The statistics for the data node.", + ) + + @model_validator(mode="before") + @classmethod + def check_data_node_exists(cls, values): + assert "data_node" in values + cls.set_data_node_and_dataset(values["data_node"]) + return values + + @classmethod + def set_is_all_previews_enabled(cls, value: bool): + cls.is_all_previews_enabled = value + + @classmethod + def set_data_node_and_dataset(cls, data_node): + cls.data_node = data_node + cls.dataset = cast(AbstractDataset, data_node.kedro_obj) + + # dataset.release clears the cache before loading to ensure that this issue + # does not arise: https://github.com/kedro-org/kedro-viz/pull/573. + cls.dataset.release() + + @field_validator("type") + @classmethod + def set_type(cls, _): + return cls.data_node.dataset_type + + @field_validator("filepath") + @classmethod + def set_filepath(cls, _): + dataset_description = cls.dataset._describe() + return _parse_filepath(dataset_description) + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + if not cls.data_node.is_free_input: + return f"kedro run --to-outputs={cls.data_node.name}" + return None + + @field_validator("preview") + @classmethod + def set_preview(cls, _): + if ( + not cls.data_node.is_preview_enabled() + or not hasattr(cls.dataset, "preview") + or not cls.is_all_previews_enabled + ): + return None + + try: + preview_args = ( + cls.data_node.get_preview_args() if cls.data_node.viz_metadata else None + ) + if preview_args is None: + return cls.dataset.preview() + return cls.dataset.preview(**preview_args) + + except Exception as exc: # noqa: BLE001 + logger.warning( + "'%s' could not be previewed. Full exception: %s: %s", + cls.data_node.name, + type(exc).__name__, + exc, + ) + return None + + @field_validator("preview_type") + @classmethod + def set_preview_type(cls, _): + if ( + not cls.data_node.is_preview_enabled() + or not hasattr(cls.dataset, "preview") + or not cls.is_all_previews_enabled + ): + return None + + try: + preview_type_annotation = inspect.signature( + cls.dataset.preview + ).return_annotation + # Attempt to get the name attribute, if it exists. + # Otherwise, use str to handle the annotation directly. + preview_type_name = getattr( + preview_type_annotation, "__name__", str(preview_type_annotation) + ) + return preview_type_name + + except Exception as exc: # noqa: BLE001 # pragma: no cover + logger.warning( + "'%s' did not have preview type. Full exception: %s: %s", + cls.data_node.name, + type(exc).__name__, + exc, + ) + return None + + @field_validator("stats") + @classmethod + def set_stats(cls, _): + return cls.data_node.stats + + +class TranscodedDataNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a TranscodedDataNode. + Args: + transcoded_data_node: The transcoded data node to which this metadata belongs. + + Raises: + AssertionError: If `transcoded_data_node` is not supplied during instantiation. + """ + + transcoded_data_node: TranscodedDataNode = Field(..., exclude=True) + + # Only available if the dataset has filepath set. + filepath: Optional[str] = Field( + default=None, + validate_default=True, + description="The path to the actual data file for the underlying dataset", + ) + + run_command: Optional[str] = Field( + default=None, + validate_default=True, + description="Command to run the pipeline to this node", + ) + original_type: Optional[str] = Field( + default=None, + validate_default=True, + description="The dataset type of the underlying transcoded data node original version", + ) + transcoded_types: Optional[List[str]] = Field( + default=None, + validate_default=True, + description="The list of all dataset types for the transcoded versions", + ) + + # Statistics for the underlying data node + stats: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The statistics for the transcoded data node metadata.", + ) + + @model_validator(mode="before") + @classmethod + def check_transcoded_data_node_exists(cls, values): + assert "transcoded_data_node" in values + cls.transcoded_data_node = values["transcoded_data_node"] + return values + + @field_validator("filepath") + @classmethod + def set_filepath(cls, _): + dataset_description = cls.transcoded_data_node.original_version._describe() + return _parse_filepath(dataset_description) + + @field_validator("run_command") + @classmethod + def set_run_command(cls, _): + if not cls.transcoded_data_node.is_free_input: + return f"kedro run --to-outputs={cls.transcoded_data_node.original_name}" + return None + + @field_validator("original_type") + @classmethod + def set_original_type(cls, _): + return get_dataset_type(cls.transcoded_data_node.original_version) + + @field_validator("transcoded_types") + @classmethod + def set_transcoded_types(cls, _): + return [ + get_dataset_type(transcoded_version) + for transcoded_version in cls.transcoded_data_node.transcoded_versions + ] + + @field_validator("stats") + @classmethod + def set_stats(cls, _): + return cls.transcoded_data_node.stats + + +class ParametersNodeMetadata(GraphNodeMetadata): + """Represent the metadata of a ParametersNode. + + Args: + parameters_node (ParametersNode): The underlying parameters node + for the parameters metadata node. + + Raises: + AssertionError: If parameters_node is not supplied during instantiation. + """ + + parameters_node: ParametersNode = Field(..., exclude=True) + parameters: Optional[Dict] = Field( + default=None, + validate_default=True, + description="The parameters dictionary for the parameters metadata node", + ) + + @model_validator(mode="before") + @classmethod + def check_parameters_node_exists(cls, values): + assert "parameters_node" in values + cls.parameters_node = values["parameters_node"] + return values + + @field_validator("parameters") + @classmethod + def set_parameters(cls, _): + if cls.parameters_node.is_single_parameter(): + return { + cls.parameters_node.parameter_name: cls.parameters_node.parameter_value + } + return cls.parameters_node.parameter_value diff --git a/package/kedro_viz/models/flowchart.py b/package/kedro_viz/models/flowchart/nodes.py similarity index 53% rename from package/kedro_viz/models/flowchart.py rename to package/kedro_viz/models/flowchart/nodes.py index 299dbc120e..0289fe1e1e 100644 --- a/package/kedro_viz/models/flowchart.py +++ b/package/kedro_viz/models/flowchart/nodes.py @@ -1,12 +1,8 @@ -"""`kedro_viz.models.flowchart` defines data models to represent Kedro entities in a viz graph.""" +"""`kedro_viz.models.flowchart.nodes` defines models to represent Kedro nodes in a viz graph.""" -import abc -import inspect import logging -from enum import Enum -from pathlib import Path -from types import FunctionType -from typing import Any, ClassVar, Dict, List, Optional, Set, Union, cast +from abc import ABC +from typing import Any, Dict, Optional, Set, Union, cast from fastapi.encoders import jsonable_encoder from kedro.pipeline.node import Node as KedroNode @@ -19,9 +15,6 @@ model_validator, ) -from kedro_viz.models.utils import get_dataset_type -from kedro_viz.utils import TRANSCODING_SEPARATOR, _strip_transcoding - try: # kedro 0.18.11 onwards from kedro.io.core import DatasetError @@ -35,75 +28,15 @@ # older versions from kedro.io.core import AbstractDataSet as AbstractDataset # type: ignore -logger = logging.getLogger(__name__) - - -def _parse_filepath(dataset_description: Dict[str, Any]) -> Optional[str]: - filepath = dataset_description.get("filepath") or dataset_description.get("path") - return str(filepath) if filepath else None - - -class NamedEntity(BaseModel): - """Represent a named entity (Tag/Registered Pipeline) in a Kedro project - Args: - id (str): Id of the registered pipeline - - Raises: - AssertionError: If id is not supplied during instantiation - """ - - id: str - name: Optional[str] = Field( - default=None, - validate_default=True, - description="The name of the registered pipeline", - ) - - @field_validator("name") - @classmethod - def set_name(cls, _, info: ValidationInfo): - assert "id" in info.data - return info.data["id"] - - -class RegisteredPipeline(NamedEntity): - """Represent a registered pipeline in a Kedro project""" - - -class GraphNodeType(str, Enum): - """Represent all possible node types in the graph representation of a Kedro pipeline. - The type needs to inherit from str as well so FastAPI can serialise it. See: - https://fastapi.tiangolo.com/tutorial/path-params/#working-with-python-enumerations - """ - - TASK = "task" - DATA = "data" - PARAMETERS = "parameters" - MODULAR_PIPELINE = ( - "modularPipeline" # camelCase so it can be referred directly to in the frontend - ) - - -class ModularPipelineChild(BaseModel, frozen=True): - """Represent a child of a modular pipeline. - - Args: - id (str): Id of the modular pipeline child - type (GraphNodeType): Type of modular pipeline child - """ - - id: str - type: GraphNodeType - +from kedro_viz.models.utils import get_dataset_type +from kedro_viz.utils import TRANSCODING_SEPARATOR, _strip_transcoding -class Tag(NamedEntity): - """Represent a tag in a Kedro project""" +from .model_utils import GraphNodeType - def __hash__(self) -> int: - return hash(self.id) +logger = logging.getLogger(__name__) -class GraphNode(BaseModel, abc.ABC): +class GraphNode(BaseModel, ABC): """Represent a node in the graph representation of a Kedro pipeline. All node models except the metadata node models should inherit from this class @@ -281,8 +214,16 @@ def has_metadata(self) -> bool: return self.kedro_obj is not None -class GraphNodeMetadata(BaseModel, abc.ABC): - """Represent a graph node's metadata""" +class ModularPipelineChild(BaseModel, frozen=True): + """Represent a child of a modular pipeline. + + Args: + id (str): Id of the modular pipeline child + type (GraphNodeType): Type of modular pipeline child + """ + + id: str + type: GraphNodeType class TaskNode(GraphNode): @@ -317,154 +258,6 @@ def set_namespace(cls, _, info: ValidationInfo): return info.data["kedro_obj"].namespace -def _extract_wrapped_func(func: FunctionType) -> FunctionType: - """Extract a wrapped decorated function to inspect the source code if available. - Adapted from https://stackoverflow.com/a/43506509/1684058 - """ - if func.__closure__ is None: - return func - closure = (c.cell_contents for c in func.__closure__) - wrapped_func = next((c for c in closure if isinstance(c, FunctionType)), None) - # return the original function if it's not a decorated function - return func if wrapped_func is None else wrapped_func - - -class ModularPipelineNode(GraphNode): - """Represent a modular pipeline node in the graph""" - - # A modular pipeline doesn't belong to any other modular pipeline, - # in the same sense as other types of GraphNode do. - # Therefore it's default to None. - # The parent-child relationship between modular pipeline themselves is modelled explicitly. - modular_pipelines: Optional[Set[str]] = None - - # Model the modular pipelines tree using a child-references representation of a tree. - # See: https://docs.mongodb.com/manual/tutorial/model-tree-structures-with-child-references/ - # for more details. - # For example, if a node namespace is "uk.data_science", - # the "uk" modular pipeline node's children are ["uk.data_science"] - children: Set[ModularPipelineChild] = Field( - set(), description="The children for the modular pipeline node" - ) - - inputs: Set[str] = Field( - set(), description="The input datasets to the modular pipeline node" - ) - - outputs: Set[str] = Field( - set(), description="The output datasets from the modular pipeline node" - ) - - # The type for Modular Pipeline Node - type: str = GraphNodeType.MODULAR_PIPELINE.value - - -class TaskNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a TaskNode - - Args: - task_node (TaskNode): Task node to which this metadata belongs to. - - Raises: - AssertionError: If task_node is not supplied during instantiation - """ - - task_node: TaskNode = Field(..., exclude=True) - - code: Optional[str] = Field( - default=None, - validate_default=True, - description="Source code of the node's function", - ) - - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="Path to the file where the node is defined", - ) - - parameters: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The parameters of the node, if available", - ) - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="The command to run the pipeline to this node", - ) - - inputs: Optional[List[str]] = Field( - default=None, validate_default=True, description="The inputs to the TaskNode" - ) - outputs: Optional[List[str]] = Field( - default=None, validate_default=True, description="The outputs from the TaskNode" - ) - - @model_validator(mode="before") - @classmethod - def check_task_node_exists(cls, values): - assert "task_node" in values - cls.set_task_and_kedro_node(values["task_node"]) - return values - - @classmethod - def set_task_and_kedro_node(cls, task_node): - cls.task_node = task_node - cls.kedro_node = cast(KedroNode, task_node.kedro_obj) - - @field_validator("code") - @classmethod - def set_code(cls, code): - # this is required to handle partial, curry functions - if inspect.isfunction(cls.kedro_node.func): - code = inspect.getsource(_extract_wrapped_func(cls.kedro_node.func)) - return code - - return None - - @field_validator("filepath") - @classmethod - def set_filepath(cls, filepath): - # this is required to handle partial, curry functions - if inspect.isfunction(cls.kedro_node.func): - code_full_path = ( - Path(inspect.getfile(cls.kedro_node.func)).expanduser().resolve() - ) - - try: - filepath = code_full_path.relative_to(Path.cwd().parent) - except ValueError: # pragma: no cover - # if the filepath can't be resolved relative to the current directory, - # e.g. either during tests or during launching development server - # outside of a Kedro project, simply return the fullpath to the file. - filepath = code_full_path - - return str(filepath) - - return None - - @field_validator("parameters") - @classmethod - def set_parameters(cls, _): - return cls.task_node.parameters - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - return f"kedro run --to-nodes='{cls.kedro_node.name}'" - - @field_validator("inputs") - @classmethod - def set_inputs(cls, _): - return cls.kedro_node.inputs - - @field_validator("outputs") - @classmethod - def set_outputs(cls, _): - return cls.kedro_node.outputs - - class DataNode(GraphNode): """Represent a graph node of type data @@ -580,241 +373,6 @@ def has_metadata(self) -> bool: return True -class DataNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a DataNode - - Args: - data_node (DataNode): Data node to which this metadata belongs to. - - Attributes: - is_all_previews_enabled (bool): Class-level attribute to determine if - previews are enabled for all nodes. This can be configured via CLI - or UI to manage the preview settings. - - Raises: - AssertionError: If data_node is not supplied during instantiation - """ - - data_node: DataNode = Field(..., exclude=True) - - is_all_previews_enabled: ClassVar[bool] = True - - type: Optional[str] = Field( - default=None, validate_default=True, description="The type of the data node" - ) - - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="The path to the actual data file for the underlying dataset", - ) - - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="Command to run the pipeline to this node", - ) - - preview: Optional[Union[Dict, str]] = Field( - default=None, - validate_default=True, - description="Preview data for the underlying datanode", - ) - - preview_type: Optional[str] = Field( - default=None, - validate_default=True, - description="Type of preview for the dataset", - ) - - stats: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The statistics for the data node.", - ) - - @model_validator(mode="before") - @classmethod - def check_data_node_exists(cls, values): - assert "data_node" in values - cls.set_data_node_and_dataset(values["data_node"]) - return values - - @classmethod - def set_is_all_previews_enabled(cls, value: bool): - cls.is_all_previews_enabled = value - - @classmethod - def set_data_node_and_dataset(cls, data_node): - cls.data_node = data_node - cls.dataset = cast(AbstractDataset, data_node.kedro_obj) - - # dataset.release clears the cache before loading to ensure that this issue - # does not arise: https://github.com/kedro-org/kedro-viz/pull/573. - cls.dataset.release() - - @field_validator("type") - @classmethod - def set_type(cls, _): - return cls.data_node.dataset_type - - @field_validator("filepath") - @classmethod - def set_filepath(cls, _): - dataset_description = cls.dataset._describe() - return _parse_filepath(dataset_description) - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - if not cls.data_node.is_free_input: - return f"kedro run --to-outputs={cls.data_node.name}" - return None - - @field_validator("preview") - @classmethod - def set_preview(cls, _): - if ( - not cls.data_node.is_preview_enabled() - or not hasattr(cls.dataset, "preview") - or not cls.is_all_previews_enabled - ): - return None - - try: - preview_args = ( - cls.data_node.get_preview_args() if cls.data_node.viz_metadata else None - ) - if preview_args is None: - return cls.dataset.preview() - return cls.dataset.preview(**preview_args) - - except Exception as exc: # noqa: BLE001 - logger.warning( - "'%s' could not be previewed. Full exception: %s: %s", - cls.data_node.name, - type(exc).__name__, - exc, - ) - return None - - @field_validator("preview_type") - @classmethod - def set_preview_type(cls, _): - if ( - not cls.data_node.is_preview_enabled() - or not hasattr(cls.dataset, "preview") - or not cls.is_all_previews_enabled - ): - return None - - try: - preview_type_annotation = inspect.signature( - cls.dataset.preview - ).return_annotation - # Attempt to get the name attribute, if it exists. - # Otherwise, use str to handle the annotation directly. - preview_type_name = getattr( - preview_type_annotation, "__name__", str(preview_type_annotation) - ) - return preview_type_name - - except Exception as exc: # noqa: BLE001 # pragma: no cover - logger.warning( - "'%s' did not have preview type. Full exception: %s: %s", - cls.data_node.name, - type(exc).__name__, - exc, - ) - return None - - @field_validator("stats") - @classmethod - def set_stats(cls, _): - return cls.data_node.stats - - -class TranscodedDataNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a TranscodedDataNode - Args: - transcoded_data_node (TranscodedDataNode): The underlying transcoded - data node to which this metadata belongs to. - - Raises: - AssertionError: If transcoded_data_node is not supplied during instantiation - """ - - transcoded_data_node: TranscodedDataNode = Field(..., exclude=True) - - # Only available if the dataset has filepath set. - filepath: Optional[str] = Field( - default=None, - validate_default=True, - description="The path to the actual data file for the underlying dataset", - ) - - run_command: Optional[str] = Field( - default=None, - validate_default=True, - description="Command to run the pipeline to this node", - ) - original_type: Optional[str] = Field( - default=None, - validate_default=True, - description="The dataset type of the underlying transcoded data node original version", - ) - transcoded_types: Optional[List[str]] = Field( - default=None, - validate_default=True, - description="The list of all dataset types for the transcoded versions", - ) - - # Statistics for the underlying data node - stats: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The statistics for the transcoded data node metadata.", - ) - - @model_validator(mode="before") - @classmethod - def check_transcoded_data_node_exists(cls, values): - assert "transcoded_data_node" in values - cls.transcoded_data_node = values["transcoded_data_node"] - return values - - @field_validator("filepath") - @classmethod - def set_filepath(cls, _): - dataset_description = cls.transcoded_data_node.original_version._describe() - return _parse_filepath(dataset_description) - - @field_validator("run_command") - @classmethod - def set_run_command(cls, _): - if not cls.transcoded_data_node.is_free_input: - return f"kedro run --to-outputs={cls.transcoded_data_node.original_name}" - return None - - @field_validator("original_type") - @classmethod - def set_original_type(cls, _): - return get_dataset_type(cls.transcoded_data_node.original_version) - - @field_validator("transcoded_types") - @classmethod - def set_transcoded_types(cls, _): - return [ - get_dataset_type(transcoded_version) - for transcoded_version in cls.transcoded_data_node.transcoded_versions - ] - - @field_validator("stats") - @classmethod - def set_stats(cls, _): - return cls.transcoded_data_node.stats - - class ParametersNode(GraphNode): """Represent a graph node of type parameters Args: @@ -882,48 +440,31 @@ def parameter_value(self) -> Any: return None -class ParametersNodeMetadata(GraphNodeMetadata): - """Represent the metadata of a ParametersNode - - Args: - parameters_node (ParametersNode): The underlying parameters node - for the parameters metadata node. +class ModularPipelineNode(GraphNode): + """Represent a modular pipeline node in the graph""" - Raises: - AssertionError: If parameters_node is not supplied during instantiation - """ + # A modular pipeline doesn't belong to any other modular pipeline, + # in the same sense as other types of GraphNode do. + # Therefore, it's default to None. + # The parent-child relationship between modular pipeline themselves is modelled explicitly. + modular_pipelines: Optional[Set[str]] = None - parameters_node: ParametersNode = Field(..., exclude=True) - parameters: Optional[Dict] = Field( - default=None, - validate_default=True, - description="The parameters dictionary for the parameters metadata node", + # Model the modular pipelines tree using a child-references representation of a tree. + # See: https://docs.mongodb.com/manual/tutorial/model-tree-structures-with-child-references/ + # for more details. + # For example, if a node namespace is "uk.data_science", + # the "uk" modular pipeline node's children are ["uk.data_science"] + children: Set[ModularPipelineChild] = Field( + set(), description="The children for the modular pipeline node" ) - @model_validator(mode="before") - @classmethod - def check_parameters_node_exists(cls, values): - assert "parameters_node" in values - cls.parameters_node = values["parameters_node"] - return values - - @field_validator("parameters") - @classmethod - def set_parameters(cls, _): - if cls.parameters_node.is_single_parameter(): - return { - cls.parameters_node.parameter_name: cls.parameters_node.parameter_value - } - return cls.parameters_node.parameter_value - - -class GraphEdge(BaseModel, frozen=True): - """Represent an edge in the graph + inputs: Set[str] = Field( + set(), description="The input datasets to the modular pipeline node" + ) - Args: - source (str): The id of the source node. - target (str): The id of the target node. - """ + outputs: Set[str] = Field( + set(), description="The output datasets from the modular pipeline node" + ) - source: str - target: str + # The type for Modular Pipeline Node + type: str = GraphNodeType.MODULAR_PIPELINE.value diff --git a/package/kedro_viz/services/layers.py b/package/kedro_viz/services/layers.py index f8840534fc..7cba369aa1 100644 --- a/package/kedro_viz/services/layers.py +++ b/package/kedro_viz/services/layers.py @@ -5,7 +5,7 @@ from graphlib import CycleError, TopologicalSorter from typing import Dict, List, Set -from kedro_viz.models.flowchart import GraphNode +from kedro_viz.models.flowchart.nodes import GraphNode logger = logging.getLogger(__name__) diff --git a/package/tests/conftest.py b/package/tests/conftest.py index 7c66051328..c6b802974a 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -21,7 +21,8 @@ ) from kedro_viz.integrations.kedro.hooks import DatasetStatsHook from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore -from kedro_viz.models.flowchart import DataNodeMetadata, GraphNode +from kedro_viz.models.flowchart.node_metadata import DataNodeMetadata +from kedro_viz.models.flowchart.nodes import GraphNode from kedro_viz.server import populate_data diff --git a/package/tests/test_api/test_rest/test_responses.py b/package/tests/test_api/test_rest/test_responses.py index 6f4581d3a3..8dbf549416 100644 --- a/package/tests/test_api/test_rest/test_responses.py +++ b/package/tests/test_api/test_rest/test_responses.py @@ -19,7 +19,7 @@ save_api_responses_to_fs, write_api_response_to_fs, ) -from kedro_viz.models.flowchart import TaskNode +from kedro_viz.models.flowchart.nodes import TaskNode from kedro_viz.models.metadata import Metadata diff --git a/package/tests/test_data_access/test_managers.py b/package/tests/test_data_access/test_managers.py index 66bd08f1e9..abb8df9be5 100644 --- a/package/tests/test_data_access/test_managers.py +++ b/package/tests/test_data_access/test_managers.py @@ -15,11 +15,11 @@ ModularPipelinesRepository, ) from kedro_viz.integrations.utils import UnavailableDataset -from kedro_viz.models.flowchart import ( +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.named_entities import Tag +from kedro_viz.models.flowchart.nodes import ( DataNode, - GraphEdge, ParametersNode, - Tag, TaskNode, TranscodedDataNode, ) diff --git a/package/tests/test_data_access/test_repositories/test_graph.py b/package/tests/test_data_access/test_repositories/test_graph.py index c45232ebd1..51f8684368 100644 --- a/package/tests/test_data_access/test_repositories/test_graph.py +++ b/package/tests/test_data_access/test_repositories/test_graph.py @@ -4,7 +4,8 @@ GraphEdgesRepository, GraphNodesRepository, ) -from kedro_viz.models.flowchart import GraphEdge, GraphNode +from kedro_viz.models.flowchart.edge import GraphEdge +from kedro_viz.models.flowchart.nodes import GraphNode class TestGraphNodeRepository: diff --git a/package/tests/test_data_access/test_repositories/test_modular_pipelines.py b/package/tests/test_data_access/test_repositories/test_modular_pipelines.py index 5b5a5e783b..ef6058ca8b 100644 --- a/package/tests/test_data_access/test_repositories/test_modular_pipelines.py +++ b/package/tests/test_data_access/test_repositories/test_modular_pipelines.py @@ -6,11 +6,8 @@ from kedro_viz.constants import ROOT_MODULAR_PIPELINE_ID from kedro_viz.data_access.repositories import ModularPipelinesRepository -from kedro_viz.models.flowchart import ( - GraphNodeType, - ModularPipelineChild, - ModularPipelineNode, -) +from kedro_viz.models.flowchart.model_utils import GraphNodeType +from kedro_viz.models.flowchart.nodes import ModularPipelineChild, ModularPipelineNode @pytest.fixture diff --git a/package/tests/test_models/test_flowchart/__init__.py b/package/tests/test_models/test_flowchart/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/package/tests/test_models/test_flowchart.py b/package/tests/test_models/test_flowchart/test_node_metadata.py similarity index 55% rename from package/tests/test_models/test_flowchart.py rename to package/tests/test_models/test_flowchart/test_node_metadata.py index 01238f286d..f8ebd4f8ec 100644 --- a/package/tests/test_models/test_flowchart.py +++ b/package/tests/test_models/test_flowchart/test_node_metadata.py @@ -1,7 +1,6 @@ from functools import partial from pathlib import Path from textwrap import dedent -from unittest.mock import call, patch import pytest from kedro.io import MemoryDataset @@ -9,18 +8,13 @@ from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.partitions.partitioned_dataset import PartitionedDataset -from kedro_viz.models.flowchart import ( - DataNode, +from kedro_viz.models.flowchart.node_metadata import ( DataNodeMetadata, - GraphNode, - ParametersNode, ParametersNodeMetadata, - RegisteredPipeline, - TaskNode, TaskNodeMetadata, - TranscodedDataNode, TranscodedDataNodeMetadata, ) +from kedro_viz.models.flowchart.nodes import GraphNode def identity(x): @@ -56,264 +50,6 @@ def full_func(a, b, c, x): partial_func = partial(full_func, 3, 1, 4) -class TestGraphNodeCreation: - @pytest.mark.parametrize( - "namespace,expected_modular_pipelines", - [ - (None, set()), - ( - "uk.data_science.model_training", - set( - [ - "uk", - "uk.data_science", - "uk.data_science.model_training", - ] - ), - ), - ], - ) - def test_create_task_node(self, namespace, expected_modular_pipelines): - kedro_node = node( - identity, - inputs="x", - outputs="y", - name="identity_node", - tags={"tag"}, - namespace=namespace, - ) - task_node = GraphNode.create_task_node( - kedro_node, "identity_node", expected_modular_pipelines - ) - assert isinstance(task_node, TaskNode) - assert task_node.kedro_obj is kedro_node - assert task_node.name == "identity_node" - assert task_node.tags == {"tag"} - assert task_node.pipelines == set() - assert task_node.modular_pipelines == expected_modular_pipelines - assert task_node.namespace == namespace - - @pytest.mark.parametrize( - "dataset_name, expected_modular_pipelines", - [ - ("dataset", set()), - ( - "uk.data_science.model_training.dataset", - set( - [ - "uk", - "uk.data_science", - "uk.data_science.model_training", - ] - ), - ), - ], - ) - def test_create_data_node(self, dataset_name, expected_modular_pipelines): - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - dataset_id=dataset_name, - dataset_name=dataset_name, - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 5, "file_size": 1024}, - modular_pipelines=set(expected_modular_pipelines), - ) - assert isinstance(data_node, DataNode) - assert data_node.kedro_obj is kedro_dataset - assert data_node.id == dataset_name - assert data_node.name == dataset_name - assert data_node.layer == "raw" - assert data_node.tags == set() - assert data_node.pipelines == set() - assert data_node.modular_pipelines == expected_modular_pipelines - assert data_node.stats["rows"] == 10 - assert data_node.stats["columns"] == 5 - assert data_node.stats["file_size"] == 1024 - - @pytest.mark.parametrize( - "transcoded_dataset_name, original_name", - [ - ("dataset@pandas2", "dataset"), - ( - "uk.data_science.model_training.dataset@pandas2", - "uk.data_science.model_training.dataset", - ), - ], - ) - def test_create_transcoded_data_node(self, transcoded_dataset_name, original_name): - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - dataset_id=original_name, - dataset_name=transcoded_dataset_name, - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 2, "file_size": 1048}, - modular_pipelines=set(), - ) - assert isinstance(data_node, TranscodedDataNode) - assert data_node.id == original_name - assert data_node.name == original_name - assert data_node.layer == "raw" - assert data_node.tags == set() - assert data_node.pipelines == set() - assert data_node.stats["rows"] == 10 - assert data_node.stats["columns"] == 2 - assert data_node.stats["file_size"] == 1048 - - def test_create_parameters_all_parameters(self): - parameters_dataset = MemoryDataset( - data={"test_split_ratio": 0.3, "num_epochs": 1000} - ) - parameters_node = GraphNode.create_parameters_node( - dataset_id="parameters", - dataset_name="parameters", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert parameters_node.id == "parameters" - assert parameters_node.is_all_parameters() - assert not parameters_node.is_single_parameter() - assert parameters_node.parameter_value == { - "test_split_ratio": 0.3, - "num_epochs": 1000, - } - assert not parameters_node.modular_pipelines - - @pytest.mark.parametrize( - "dataset_name,expected_modular_pipelines", - [ - ("params:test_split_ratio", set()), - ( - "params:uk.data_science.model_training.test_split_ratio", - set(["uk", "uk.data_science", "uk.data_science.model_training"]), - ), - ], - ) - def test_create_parameters_node_single_parameter( - self, dataset_name, expected_modular_pipelines - ): - parameters_dataset = MemoryDataset(data=0.3) - parameters_node = GraphNode.create_parameters_node( - dataset_id=dataset_name, - dataset_name=dataset_name, - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=expected_modular_pipelines, - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert not parameters_node.is_all_parameters() - assert parameters_node.is_single_parameter() - assert parameters_node.parameter_value == 0.3 - assert parameters_node.modular_pipelines == expected_modular_pipelines - - def test_create_single_parameter_with_complex_type(self): - parameters_dataset = MemoryDataset(data=object()) - parameters_node = GraphNode.create_parameters_node( - dataset_id="params:test_split_ratio", - dataset_name="params:test_split_ratio", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert not parameters_node.is_all_parameters() - assert parameters_node.is_single_parameter() - assert isinstance(parameters_node.parameter_value, str) - - def test_create_all_parameters_with_complex_type(self): - mock_object = object() - parameters_dataset = MemoryDataset( - data={ - "test_split_ratio": 0.3, - "num_epochs": 1000, - "complex_param": mock_object, - } - ) - parameters_node = GraphNode.create_parameters_node( - dataset_id="parameters", - dataset_name="parameters", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.kedro_obj is parameters_dataset - assert parameters_node.id == "parameters" - assert parameters_node.is_all_parameters() - assert not parameters_node.is_single_parameter() - assert isinstance(parameters_node.parameter_value, str) - - def test_create_non_existing_parameter_node(self): - """Test the case where ``parameters`` is equal to None""" - parameters_node = GraphNode.create_parameters_node( - dataset_id="non_existing", - dataset_name="non_existing", - layer=None, - tags=set(), - parameters=None, - modular_pipelines=set(), - ) - assert isinstance(parameters_node, ParametersNode) - assert parameters_node.parameter_value is None - - @patch("logging.Logger.warning") - def test_create_non_existing_parameter_node_empty_dataset(self, patched_warning): - """Test the case where ``parameters`` is equal to a MemoryDataset with no data""" - parameters_dataset = MemoryDataset() - parameters_node = GraphNode.create_parameters_node( - dataset_id="non_existing", - dataset_name="non_existing", - layer=None, - tags=set(), - parameters=parameters_dataset, - modular_pipelines=set(), - ) - assert parameters_node.parameter_value is None - patched_warning.assert_has_calls( - [call("Cannot find parameter `%s` in the catalog.", "non_existing")] - ) - - -class TestGraphNodePipelines: - def test_registered_pipeline_name(self): - pipeline = RegisteredPipeline(id="__default__") - assert pipeline.name == "__default__" - - def test_modular_pipeline_name(self): - pipeline = GraphNode.create_modular_pipeline_node("data_engineering") - assert pipeline.name == "data_engineering" - - def test_add_node_to_pipeline(self): - default_pipeline = RegisteredPipeline(id="__default__") - another_pipeline = RegisteredPipeline(id="testing") - kedro_dataset = CSVDataset(filepath="foo.csv") - data_node = GraphNode.create_data_node( - dataset_id="dataset@transcoded", - dataset_name="dataset@transcoded", - layer="raw", - tags=set(), - dataset=kedro_dataset, - stats={"rows": 10, "columns": 2, "file_size": 1048}, - modular_pipelines=set(), - ) - assert data_node.pipelines == set() - data_node.add_pipeline(default_pipeline.id) - assert data_node.belongs_to_pipeline(default_pipeline.id) - assert not data_node.belongs_to_pipeline(another_pipeline.id) - - class TestGraphNodeMetadata: @pytest.mark.parametrize( "dataset,has_metadata", [(MemoryDataset(data=1), True), (None, False)] diff --git a/package/tests/test_models/test_flowchart/test_nodes.py b/package/tests/test_models/test_flowchart/test_nodes.py new file mode 100644 index 0000000000..2d7a59d338 --- /dev/null +++ b/package/tests/test_models/test_flowchart/test_nodes.py @@ -0,0 +1,248 @@ +from unittest.mock import call, patch + +import pytest +from kedro.io import MemoryDataset +from kedro.pipeline.node import node +from kedro_datasets.pandas import CSVDataset + +from kedro_viz.models.flowchart.nodes import ( + DataNode, + GraphNode, + ParametersNode, + TaskNode, + TranscodedDataNode, +) + + +def identity(x): + return x + + +class TestGraphNodeCreation: + @pytest.mark.parametrize( + "namespace,expected_modular_pipelines", + [ + (None, set()), + ( + "uk.data_science.model_training", + set( + [ + "uk", + "uk.data_science", + "uk.data_science.model_training", + ] + ), + ), + ], + ) + def test_create_task_node(self, namespace, expected_modular_pipelines): + kedro_node = node( + identity, + inputs="x", + outputs="y", + name="identity_node", + tags={"tag"}, + namespace=namespace, + ) + task_node = GraphNode.create_task_node( + kedro_node, "identity_node", expected_modular_pipelines + ) + assert isinstance(task_node, TaskNode) + assert task_node.kedro_obj is kedro_node + assert task_node.name == "identity_node" + assert task_node.tags == {"tag"} + assert task_node.pipelines == set() + assert task_node.modular_pipelines == expected_modular_pipelines + assert task_node.namespace == namespace + + @pytest.mark.parametrize( + "dataset_name, expected_modular_pipelines", + [ + ("dataset", set()), + ( + "uk.data_science.model_training.dataset", + set( + [ + "uk", + "uk.data_science", + "uk.data_science.model_training", + ] + ), + ), + ], + ) + def test_create_data_node(self, dataset_name, expected_modular_pipelines): + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + dataset_id=dataset_name, + dataset_name=dataset_name, + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 5, "file_size": 1024}, + modular_pipelines=set(expected_modular_pipelines), + ) + assert isinstance(data_node, DataNode) + assert data_node.kedro_obj is kedro_dataset + assert data_node.id == dataset_name + assert data_node.name == dataset_name + assert data_node.layer == "raw" + assert data_node.tags == set() + assert data_node.pipelines == set() + assert data_node.modular_pipelines == expected_modular_pipelines + assert data_node.stats["rows"] == 10 + assert data_node.stats["columns"] == 5 + assert data_node.stats["file_size"] == 1024 + + @pytest.mark.parametrize( + "transcoded_dataset_name, original_name", + [ + ("dataset@pandas2", "dataset"), + ( + "uk.data_science.model_training.dataset@pandas2", + "uk.data_science.model_training.dataset", + ), + ], + ) + def test_create_transcoded_data_node(self, transcoded_dataset_name, original_name): + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + dataset_id=original_name, + dataset_name=transcoded_dataset_name, + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 2, "file_size": 1048}, + modular_pipelines=set(), + ) + assert isinstance(data_node, TranscodedDataNode) + assert data_node.id == original_name + assert data_node.name == original_name + assert data_node.layer == "raw" + assert data_node.tags == set() + assert data_node.pipelines == set() + assert data_node.stats["rows"] == 10 + assert data_node.stats["columns"] == 2 + assert data_node.stats["file_size"] == 1048 + + def test_create_parameters_all_parameters(self): + parameters_dataset = MemoryDataset( + data={"test_split_ratio": 0.3, "num_epochs": 1000} + ) + parameters_node = GraphNode.create_parameters_node( + dataset_id="parameters", + dataset_name="parameters", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert parameters_node.id == "parameters" + assert parameters_node.is_all_parameters() + assert not parameters_node.is_single_parameter() + assert parameters_node.parameter_value == { + "test_split_ratio": 0.3, + "num_epochs": 1000, + } + assert not parameters_node.modular_pipelines + + @pytest.mark.parametrize( + "dataset_name,expected_modular_pipelines", + [ + ("params:test_split_ratio", set()), + ( + "params:uk.data_science.model_training.test_split_ratio", + set(["uk", "uk.data_science", "uk.data_science.model_training"]), + ), + ], + ) + def test_create_parameters_node_single_parameter( + self, dataset_name, expected_modular_pipelines + ): + parameters_dataset = MemoryDataset(data=0.3) + parameters_node = GraphNode.create_parameters_node( + dataset_id=dataset_name, + dataset_name=dataset_name, + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=expected_modular_pipelines, + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert not parameters_node.is_all_parameters() + assert parameters_node.is_single_parameter() + assert parameters_node.parameter_value == 0.3 + assert parameters_node.modular_pipelines == expected_modular_pipelines + + def test_create_single_parameter_with_complex_type(self): + parameters_dataset = MemoryDataset(data=object()) + parameters_node = GraphNode.create_parameters_node( + dataset_id="params:test_split_ratio", + dataset_name="params:test_split_ratio", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert not parameters_node.is_all_parameters() + assert parameters_node.is_single_parameter() + assert isinstance(parameters_node.parameter_value, str) + + def test_create_all_parameters_with_complex_type(self): + mock_object = object() + parameters_dataset = MemoryDataset( + data={ + "test_split_ratio": 0.3, + "num_epochs": 1000, + "complex_param": mock_object, + } + ) + parameters_node = GraphNode.create_parameters_node( + dataset_id="parameters", + dataset_name="parameters", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.kedro_obj is parameters_dataset + assert parameters_node.id == "parameters" + assert parameters_node.is_all_parameters() + assert not parameters_node.is_single_parameter() + assert isinstance(parameters_node.parameter_value, str) + + def test_create_non_existing_parameter_node(self): + """Test the case where ``parameters`` is equal to None""" + parameters_node = GraphNode.create_parameters_node( + dataset_id="non_existing", + dataset_name="non_existing", + layer=None, + tags=set(), + parameters=None, + modular_pipelines=set(), + ) + assert isinstance(parameters_node, ParametersNode) + assert parameters_node.parameter_value is None + + @patch("logging.Logger.warning") + def test_create_non_existing_parameter_node_empty_dataset(self, patched_warning): + """Test the case where ``parameters`` is equal to a MemoryDataset with no data""" + parameters_dataset = MemoryDataset() + parameters_node = GraphNode.create_parameters_node( + dataset_id="non_existing", + dataset_name="non_existing", + layer=None, + tags=set(), + parameters=parameters_dataset, + modular_pipelines=set(), + ) + assert parameters_node.parameter_value is None + patched_warning.assert_has_calls( + [call("Cannot find parameter `%s` in the catalog.", "non_existing")] + ) diff --git a/package/tests/test_models/test_flowchart/test_pipeline.py b/package/tests/test_models/test_flowchart/test_pipeline.py new file mode 100644 index 0000000000..520aff01d9 --- /dev/null +++ b/package/tests/test_models/test_flowchart/test_pipeline.py @@ -0,0 +1,32 @@ +from kedro_datasets.pandas import CSVDataset + +from kedro_viz.models.flowchart.named_entities import RegisteredPipeline +from kedro_viz.models.flowchart.nodes import GraphNode + + +class TestGraphNodePipelines: + def test_registered_pipeline_name(self): + pipeline = RegisteredPipeline(id="__default__") + assert pipeline.name == "__default__" + + def test_modular_pipeline_name(self): + pipeline = GraphNode.create_modular_pipeline_node("data_engineering") + assert pipeline.name == "data_engineering" + + def test_add_node_to_pipeline(self): + default_pipeline = RegisteredPipeline(id="__default__") + another_pipeline = RegisteredPipeline(id="testing") + kedro_dataset = CSVDataset(filepath="foo.csv") + data_node = GraphNode.create_data_node( + dataset_id="dataset@transcoded", + dataset_name="dataset@transcoded", + layer="raw", + tags=set(), + dataset=kedro_dataset, + stats={"rows": 10, "columns": 2, "file_size": 1048}, + modular_pipelines=set(), + ) + assert data_node.pipelines == set() + data_node.add_pipeline(default_pipeline.id) + assert data_node.belongs_to_pipeline(default_pipeline.id) + assert not data_node.belongs_to_pipeline(another_pipeline.id) diff --git a/package/tests/test_services/test_layers.py b/package/tests/test_services/test_layers.py index 80d76fae5a..c949a9f98b 100644 --- a/package/tests/test_services/test_layers.py +++ b/package/tests/test_services/test_layers.py @@ -1,6 +1,6 @@ import pytest -from kedro_viz.models.flowchart import GraphNode +from kedro_viz.models.flowchart.nodes import GraphNode from kedro_viz.services.layers import sort_layers diff --git a/ruff.toml b/ruff.toml index 52a1d6c8f3..166d54a4a7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -45,7 +45,8 @@ ignore = [ "package/features/steps/sh_run.py" = ["PLW1510"] # `subprocess.run` without explicit `check` argument "*/tests/*.py" = ["SLF", "D", "ARG"] "package/kedro_viz/models/experiment_tracking.py" = ["SLF"] -"package/kedro_viz/models/flowchart.py" = ["SLF"] +"package/kedro_viz/models/flowchart/nodes.py" = ["SLF"] +"package/kedro_viz/models/flowchart/node_metadata.py" = ["SLF"] "package/kedro_viz/integrations/kedro/hooks.py" = ["SLF", "BLE"] "package/kedro_viz/integrations/kedro/sqlite_store.py" = ["BLE"] "package/kedro_viz/integrations/kedro/data_loader.py" = ["SLF"] From 31b0492d7161914252fcbb894e0eea82ec5037c6 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Pilla Date: Wed, 30 Oct 2024 14:33:31 -0500 Subject: [PATCH 2/2] Refactor response classes (#2113) * sync remote * refactor response classes * adjust permissions * adjuste permissions * move from multiprocessing to threading * fix file perm * revert threading * revert except block * adjust lint and tests * fix lint * fix perm * update comments * changes based on PR comments * changes based on PR comments * update file comments * remove pylint * adjust attributes * test assert helper * test assert helper * test assert helper * test assert helper --- package/kedro_viz/api/apps.py | 2 +- package/kedro_viz/api/rest/responses.py | 492 --------------- .../kedro_viz/api/rest/responses/__init__.py | 0 package/kedro_viz/api/rest/responses/base.py | 28 + .../kedro_viz/api/rest/responses/metadata.py | 47 ++ package/kedro_viz/api/rest/responses/nodes.py | 162 +++++ .../kedro_viz/api/rest/responses/pipelines.py | 256 ++++++++ .../api/rest/responses/save_responses.py | 97 +++ package/kedro_viz/api/rest/responses/utils.py | 44 ++ package/kedro_viz/api/rest/router.py | 35 +- package/kedro_viz/data_access/managers.py | 3 +- .../integrations/deployment/base_deployer.py | 2 +- package/kedro_viz/launchers/cli/deploy.py | 6 +- package/kedro_viz/launchers/cli/run.py | 2 +- package/kedro_viz/launchers/cli/utils.py | 13 +- package/kedro_viz/launchers/utils.py | 11 + package/kedro_viz/server.py | 14 +- package/tests/conftest.py | 28 +- .../test_rest/test_responses/__init__.py | 0 .../assert_helpers.py} | 570 +----------------- .../test_rest/test_responses/test_base.py | 10 + .../test_rest/test_responses/test_metadata.py | 24 + .../test_rest/test_responses/test_nodes.py | 91 +++ .../test_responses/test_pipelines.py | 241 ++++++++ .../test_responses/test_save_responses.py | 168 ++++++ .../test_rest/test_responses/test_utils.py | 43 ++ .../tests/test_api/test_rest/test_router.py | 2 +- package/tests/test_server.py | 2 +- 28 files changed, 1310 insertions(+), 1083 deletions(-) delete mode 100644 package/kedro_viz/api/rest/responses.py create mode 100644 package/kedro_viz/api/rest/responses/__init__.py create mode 100755 package/kedro_viz/api/rest/responses/base.py create mode 100755 package/kedro_viz/api/rest/responses/metadata.py create mode 100644 package/kedro_viz/api/rest/responses/nodes.py create mode 100644 package/kedro_viz/api/rest/responses/pipelines.py create mode 100644 package/kedro_viz/api/rest/responses/save_responses.py create mode 100644 package/kedro_viz/api/rest/responses/utils.py create mode 100755 package/tests/test_api/test_rest/test_responses/__init__.py rename package/tests/test_api/test_rest/{test_responses.py => test_responses/assert_helpers.py} (50%) create mode 100755 package/tests/test_api/test_rest/test_responses/test_base.py create mode 100755 package/tests/test_api/test_rest/test_responses/test_metadata.py create mode 100644 package/tests/test_api/test_rest/test_responses/test_nodes.py create mode 100755 package/tests/test_api/test_rest/test_responses/test_pipelines.py create mode 100644 package/tests/test_api/test_rest/test_responses/test_save_responses.py create mode 100644 package/tests/test_api/test_rest/test_responses/test_utils.py diff --git a/package/kedro_viz/api/apps.py b/package/kedro_viz/api/apps.py index d5b5c535ca..e188ab1911 100644 --- a/package/kedro_viz/api/apps.py +++ b/package/kedro_viz/api/apps.py @@ -15,7 +15,7 @@ from jinja2 import Environment, FileSystemLoader from kedro_viz import __version__ -from kedro_viz.api.rest.responses import EnhancedORJSONResponse +from kedro_viz.api.rest.responses.utils import EnhancedORJSONResponse from kedro_viz.integrations.kedro import telemetry as kedro_telemetry from .graphql.router import router as graphql_router diff --git a/package/kedro_viz/api/rest/responses.py b/package/kedro_viz/api/rest/responses.py deleted file mode 100644 index 1e885eced1..0000000000 --- a/package/kedro_viz/api/rest/responses.py +++ /dev/null @@ -1,492 +0,0 @@ -"""`kedro_viz.api.rest.responses` defines REST response types.""" - -import abc -import json -import logging -from typing import Any, Dict, List, Optional, Union - -import orjson -from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse, ORJSONResponse -from pydantic import BaseModel, ConfigDict - -from kedro_viz.api.rest.utils import get_package_compatibilities -from kedro_viz.data_access import data_access_manager -from kedro_viz.models.flowchart.node_metadata import ( - DataNodeMetadata, - ParametersNodeMetadata, - TaskNodeMetadata, - TranscodedDataNodeMetadata, -) -from kedro_viz.models.flowchart.nodes import DataNode, TaskNode, TranscodedDataNode -from kedro_viz.models.metadata import Metadata, PackageCompatibility - -logger = logging.getLogger(__name__) - - -class APIErrorMessage(BaseModel): - message: str - - -class BaseAPIResponse(BaseModel, abc.ABC): - model_config = ConfigDict(from_attributes=True) - - -class BaseGraphNodeAPIResponse(BaseAPIResponse): - id: str - name: str - tags: List[str] - pipelines: List[str] - type: str - - # If a node is a ModularPipeline node, this value will be None, hence Optional. - modular_pipelines: Optional[List[str]] = None - - -class TaskNodeAPIResponse(BaseGraphNodeAPIResponse): - parameters: Dict - model_config = ConfigDict( - json_schema_extra={ - "example": { - "id": "6ab908b8", - "name": "split_data_node", - "tags": [], - "pipelines": ["__default__", "ds"], - "modular_pipelines": [], - "type": "task", - "parameters": { - "test_size": 0.2, - "random_state": 3, - "features": [ - "engines", - "passenger_capacity", - "crew", - "d_check_complete", - "moon_clearance_complete", - "iata_approved", - "company_rating", - "review_scores_rating", - ], - }, - } - } - ) - - -class DataNodeAPIResponse(BaseGraphNodeAPIResponse): - layer: Optional[str] = None - dataset_type: Optional[str] = None - stats: Optional[Dict] = None - model_config = ConfigDict( - json_schema_extra={ - "example": { - "id": "d7b83b05", - "name": "master_table", - "tags": [], - "pipelines": ["__default__", "dp", "ds"], - "modular_pipelines": [], - "type": "data", - "layer": "primary", - "dataset_type": "kedro_datasets.pandas.csv_dataset.CSVDataset", - "stats": {"rows": 10, "columns": 2, "file_size": 2300}, - } - } - ) - - -NodeAPIResponse = Union[ - TaskNodeAPIResponse, - DataNodeAPIResponse, -] - - -class TaskNodeMetadataAPIResponse(BaseAPIResponse): - code: Optional[str] = None - filepath: Optional[str] = None - parameters: Optional[Dict] = None - inputs: List[str] - outputs: List[str] - run_command: Optional[str] = None - model_config = ConfigDict( - json_schema_extra={ - "example": { - "code": "def split_data(data: pd.DataFrame, parameters: Dict) -> Tuple:", - "filepath": "proj/src/new_kedro_project/pipelines/data_science/nodes.py", - "parameters": {"test_size": 0.2}, - "inputs": ["params:input1", "input2"], - "outputs": ["output1"], - "run_command": "kedro run --to-nodes=split_data", - } - } - ) - - -class DataNodeMetadataAPIResponse(BaseAPIResponse): - filepath: Optional[str] = None - type: str - run_command: Optional[str] = None - preview: Optional[Union[Dict, str]] = None - preview_type: Optional[str] = None - stats: Optional[Dict] = None - model_config = ConfigDict( - json_schema_extra={ - "example": { - "filepath": "/my-kedro-project/data/03_primary/master_table.csv", - "type": "kedro_datasets.pandas.csv_dataset.CSVDataset", - "run_command": "kedro run --to-outputs=master_table", - } - } - ) - - -class TranscodedDataNodeMetadataAPIReponse(BaseAPIResponse): - filepath: Optional[str] = None - original_type: str - transcoded_types: List[str] - run_command: Optional[str] = None - stats: Optional[Dict] = None - - -class ParametersNodeMetadataAPIResponse(BaseAPIResponse): - parameters: Dict - model_config = ConfigDict( - json_schema_extra={ - "example": { - "parameters": { - "test_size": 0.2, - "random_state": 3, - "features": [ - "engines", - "passenger_capacity", - "crew", - "d_check_complete", - "moon_clearance_complete", - "iata_approved", - "company_rating", - "review_scores_rating", - ], - } - } - } - ) - - -NodeMetadataAPIResponse = Union[ - TaskNodeMetadataAPIResponse, - DataNodeMetadataAPIResponse, - TranscodedDataNodeMetadataAPIReponse, - ParametersNodeMetadataAPIResponse, -] - - -class GraphEdgeAPIResponse(BaseAPIResponse): - source: str - target: str - - -class NamedEntityAPIResponse(BaseAPIResponse): - """Model an API field that has an ID and a name. - For example, used for representing modular pipelines and pipelines in the API response. - """ - - id: str - name: Optional[str] = None - - -class ModularPipelineChildAPIResponse(BaseAPIResponse): - """Model a child in a modular pipeline's children field in the API response.""" - - id: str - type: str - - -class ModularPipelinesTreeNodeAPIResponse(BaseAPIResponse): - """Model a node in the tree representation of modular pipelines in the API response.""" - - id: str - name: str - inputs: List[str] - outputs: List[str] - children: List[ModularPipelineChildAPIResponse] - - -# Represent the modular pipelines in the API response as a tree. -# The root node is always designated with the __root__ key. -# Example: -# { -# "__root__": { -# "id": "__root__", -# "name": "Root", -# "inputs": [], -# "outputs": [], -# "children": [ -# {"id": "d577578a", "type": "parameters"}, -# {"id": "data_science", "type": "modularPipeline"}, -# {"id": "f1f1425b", "type": "parameters"}, -# {"id": "data_engineering", "type": "modularPipeline"}, -# ], -# }, -# "data_engineering": { -# "id": "data_engineering", -# "name": "Data Engineering", -# "inputs": ["d577578a"], -# "outputs": [], -# "children": [], -# }, -# "data_science": { -# "id": "data_science", -# "name": "Data Science", -# "inputs": ["f1f1425b"], -# "outputs": [], -# "children": [], -# }, -# } -# } -ModularPipelinesTreeAPIResponse = Dict[str, ModularPipelinesTreeNodeAPIResponse] - - -class GraphAPIResponse(BaseAPIResponse): - nodes: List[NodeAPIResponse] - edges: List[GraphEdgeAPIResponse] - layers: List[str] - tags: List[NamedEntityAPIResponse] - pipelines: List[NamedEntityAPIResponse] - modular_pipelines: ModularPipelinesTreeAPIResponse - selected_pipeline: str - - -class MetadataAPIResponse(BaseAPIResponse): - has_missing_dependencies: bool = False - package_compatibilities: List[PackageCompatibility] = [] - model_config = ConfigDict( - json_schema_extra={ - "has_missing_dependencies": False, - "package_compatibilities": [ - { - "package_name": "fsspec", - "package_version": "2024.6.1", - "is_compatible": True, - }, - { - "package_name": "kedro-datasets", - "package_version": "4.0.0", - "is_compatible": True, - }, - ], - } - ) - - -class EnhancedORJSONResponse(ORJSONResponse): - @staticmethod - def encode_to_human_readable(content: Any) -> bytes: - """A method to encode the given content to JSON, with the - proper formatting to write a human-readable file. - - Returns: - A bytes object containing the JSON to write. - - """ - return orjson.dumps( - content, - option=orjson.OPT_INDENT_2 - | orjson.OPT_NON_STR_KEYS - | orjson.OPT_SERIALIZE_NUMPY, - ) - - -def get_default_response() -> GraphAPIResponse: - """Default response for `/api/main`.""" - default_selected_pipeline_id = ( - data_access_manager.get_default_selected_pipeline().id - ) - - modular_pipelines_tree = ( - data_access_manager.create_modular_pipelines_tree_for_registered_pipeline( - default_selected_pipeline_id - ) - ) - - return GraphAPIResponse( - nodes=data_access_manager.get_nodes_for_registered_pipeline( - default_selected_pipeline_id - ), - edges=data_access_manager.get_edges_for_registered_pipeline( - default_selected_pipeline_id - ), - tags=data_access_manager.tags.as_list(), - layers=data_access_manager.get_sorted_layers_for_registered_pipeline( - default_selected_pipeline_id - ), - pipelines=data_access_manager.registered_pipelines.as_list(), - modular_pipelines=modular_pipelines_tree, - selected_pipeline=default_selected_pipeline_id, - ) - - -def get_node_metadata_response(node_id: str): - """API response for `/api/nodes/node_id`.""" - node = data_access_manager.nodes.get_node_by_id(node_id) - if not node: - return JSONResponse(status_code=404, content={"message": "Invalid node ID"}) - - if not node.has_metadata(): - return JSONResponse(content={}) - - if isinstance(node, TaskNode): - return TaskNodeMetadata(task_node=node) - - if isinstance(node, DataNode): - return DataNodeMetadata(data_node=node) - - if isinstance(node, TranscodedDataNode): - return TranscodedDataNodeMetadata(transcoded_data_node=node) - - return ParametersNodeMetadata(parameters_node=node) - - -def get_selected_pipeline_response(registered_pipeline_id: str): - """API response for `/api/pipeline/pipeline_id`.""" - if not data_access_manager.registered_pipelines.has_pipeline( - registered_pipeline_id - ): - return JSONResponse(status_code=404, content={"message": "Invalid pipeline ID"}) - - modular_pipelines_tree = ( - data_access_manager.create_modular_pipelines_tree_for_registered_pipeline( - registered_pipeline_id - ) - ) - - return GraphAPIResponse( - nodes=data_access_manager.get_nodes_for_registered_pipeline( - registered_pipeline_id - ), - edges=data_access_manager.get_edges_for_registered_pipeline( - registered_pipeline_id - ), - tags=data_access_manager.tags.as_list(), - layers=data_access_manager.get_sorted_layers_for_registered_pipeline( - registered_pipeline_id - ), - pipelines=data_access_manager.registered_pipelines.as_list(), - selected_pipeline=registered_pipeline_id, - modular_pipelines=modular_pipelines_tree, - ) - - -def get_metadata_response(): - """API response for `/api/metadata`.""" - package_compatibilities = get_package_compatibilities() - Metadata.set_package_compatibilities(package_compatibilities) - return Metadata() - - -def get_encoded_response(response: Any) -> bytes: - """Encodes and enhances the default response using human-readable format.""" - jsonable_response = jsonable_encoder(response) - encoded_response = EnhancedORJSONResponse.encode_to_human_readable( - jsonable_response - ) - - return encoded_response - - -def write_api_response_to_fs(file_path: str, response: Any, remote_fs: Any): - """Get encoded responses and writes it to a file""" - encoded_response = get_encoded_response(response) - - with remote_fs.open(file_path, "wb") as file: - file.write(encoded_response) - - -def get_kedro_project_json_data(): - """Decodes the default response and returns the Kedro project JSON data. - This will be used in VSCode extension to get current Kedro project data.""" - encoded_response = get_encoded_response(get_default_response()) - - try: - response_str = encoded_response.decode("utf-8") - json_data = json.loads(response_str) - except UnicodeDecodeError as exc: # pragma: no cover - json_data = None - logger.error("Failed to decode response string. Error: %s", str(exc)) - except json.JSONDecodeError as exc: # pragma: no cover - json_data = None - logger.error("Failed to parse JSON data. Error: %s", str(exc)) - - return json_data - - -def save_api_main_response_to_fs(main_path: str, remote_fs: Any): - """Saves API /main response to a directory.""" - try: - write_api_response_to_fs(main_path, get_default_response(), remote_fs) - except Exception as exc: # pragma: no cover - logger.exception("Failed to save default response. Error: %s", str(exc)) - raise exc - - -def save_api_node_response_to_fs( - nodes_path: str, remote_fs: Any, is_all_previews_enabled: bool -): - """Saves API /nodes/{node} response to a directory.""" - # Set if preview is enabled/disabled for all data nodes - DataNodeMetadata.set_is_all_previews_enabled(is_all_previews_enabled) - - for nodeId in data_access_manager.nodes.get_node_ids(): - try: - write_api_response_to_fs( - f"{nodes_path}/{nodeId}", get_node_metadata_response(nodeId), remote_fs - ) - except Exception as exc: # pragma: no cover - logger.exception( - "Failed to save node data for node ID %s. Error: %s", nodeId, str(exc) - ) - raise exc - - -def save_api_pipeline_response_to_fs(pipelines_path: str, remote_fs: Any): - """Saves API /pipelines/{pipeline} response to a directory.""" - for pipelineId in data_access_manager.registered_pipelines.get_pipeline_ids(): - try: - write_api_response_to_fs( - f"{pipelines_path}/{pipelineId}", - get_selected_pipeline_response(pipelineId), - remote_fs, - ) - except Exception as exc: # pragma: no cover - logger.exception( - "Failed to save pipeline data for pipeline ID %s. Error: %s", - pipelineId, - str(exc), - ) - raise exc - - -def save_api_responses_to_fs(path: str, remote_fs: Any, is_all_previews_enabled: bool): - """Saves all Kedro Viz API responses to a directory.""" - try: - logger.debug( - """Saving/Uploading api files to %s""", - path, - ) - - main_path = f"{path}/api/main" - nodes_path = f"{path}/api/nodes" - pipelines_path = f"{path}/api/pipelines" - - if "file" in remote_fs.protocol: - remote_fs.makedirs(path, exist_ok=True) - remote_fs.makedirs(nodes_path, exist_ok=True) - remote_fs.makedirs(pipelines_path, exist_ok=True) - - save_api_main_response_to_fs(main_path, remote_fs) - save_api_node_response_to_fs(nodes_path, remote_fs, is_all_previews_enabled) - save_api_pipeline_response_to_fs(pipelines_path, remote_fs) - - except Exception as exc: # pragma: no cover - logger.exception( - "An error occurred while preparing data for saving. Error: %s", str(exc) - ) - raise exc diff --git a/package/kedro_viz/api/rest/responses/__init__.py b/package/kedro_viz/api/rest/responses/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/package/kedro_viz/api/rest/responses/base.py b/package/kedro_viz/api/rest/responses/base.py new file mode 100755 index 0000000000..99fe66e85c --- /dev/null +++ b/package/kedro_viz/api/rest/responses/base.py @@ -0,0 +1,28 @@ +"""`kedro_viz.api.rest.responses.base` contains base +response classes and utility functions for the REST endpoints""" + +import abc +import logging + +from pydantic import BaseModel, ConfigDict + +logger = logging.getLogger(__name__) + + +class APINotFoundResponse(BaseModel): + """ + APINotFoundResponse is a Pydantic model representing a response for an API not found error. + + Attributes: + message (str): A message describing the error. + """ + + message: str + + +class BaseAPIResponse(BaseModel, abc.ABC): + """ + BaseAPIResponse is an abstract base class for API responses. + """ + + model_config = ConfigDict(from_attributes=True) diff --git a/package/kedro_viz/api/rest/responses/metadata.py b/package/kedro_viz/api/rest/responses/metadata.py new file mode 100755 index 0000000000..0222d261a1 --- /dev/null +++ b/package/kedro_viz/api/rest/responses/metadata.py @@ -0,0 +1,47 @@ +"""`kedro_viz.api.rest.responses.metadata` contains response classes +and utility functions for the `/metadata` REST endpoint""" + +from typing import List + +from pydantic import ConfigDict + +from kedro_viz.api.rest.responses.base import BaseAPIResponse +from kedro_viz.api.rest.utils import get_package_compatibilities +from kedro_viz.models.metadata import Metadata, PackageCompatibility + + +class MetadataAPIResponse(BaseAPIResponse): + """ + MetadataAPIResponse is a subclass of BaseAPIResponse that represents the response structure for metadata API. + + Attributes: + has_missing_dependencies (bool): Indicates if there are any missing dependencies. Defaults to False. + package_compatibilities (List[PackageCompatibility]): A list of package compatibility information. Defaults to an empty list. + """ + + has_missing_dependencies: bool = False + package_compatibilities: List[PackageCompatibility] = [] + model_config = ConfigDict( + json_schema_extra={ + "has_missing_dependencies": False, + "package_compatibilities": [ + { + "package_name": "fsspec", + "package_version": "2024.6.1", + "is_compatible": True, + }, + { + "package_name": "kedro-datasets", + "package_version": "4.0.0", + "is_compatible": True, + }, + ], + } + ) + + +def get_metadata_response(): + """API response for `/api/metadata`.""" + package_compatibilities = get_package_compatibilities() + Metadata.set_package_compatibilities(package_compatibilities) + return Metadata() diff --git a/package/kedro_viz/api/rest/responses/nodes.py b/package/kedro_viz/api/rest/responses/nodes.py new file mode 100644 index 0000000000..f6df0c53ce --- /dev/null +++ b/package/kedro_viz/api/rest/responses/nodes.py @@ -0,0 +1,162 @@ +"""`kedro_viz.api.rest.responses.nodes` contains response classes +and utility functions for the `/nodes/*` REST endpoints""" + +import logging +from typing import Dict, List, Optional, Union + +from fastapi.responses import JSONResponse +from pydantic import ConfigDict + +from kedro_viz.api.rest.responses.base import BaseAPIResponse +from kedro_viz.data_access import data_access_manager +from kedro_viz.models.flowchart.node_metadata import ( + DataNodeMetadata, + ParametersNodeMetadata, + TaskNodeMetadata, + TranscodedDataNodeMetadata, +) +from kedro_viz.models.flowchart.nodes import DataNode, TaskNode, TranscodedDataNode + +logger = logging.getLogger(__name__) + + +class TaskNodeMetadataAPIResponse(BaseAPIResponse): + """ + TaskNodeMetadataAPIResponse is a data model for representing the metadata of a task node in the Kedro visualization API. + + Attributes: + code (Optional[str]): The code snippet of the task node. + filepath (Optional[str]): The file path where the task node is defined. + parameters (Optional[Dict]): The parameters used by the task node. + inputs (List[str]): The list of input data for the task node. + outputs (List[str]): The list of output data from the task node. + run_command (Optional[str]): The command to run the task node. + """ + + code: Optional[str] = None + filepath: Optional[str] = None + parameters: Optional[Dict] = None + inputs: List[str] + outputs: List[str] + run_command: Optional[str] = None + model_config = ConfigDict( + json_schema_extra={ + "example": { + "code": "def split_data(data: pd.DataFrame, parameters: Dict) -> Tuple:", + "filepath": "proj/src/new_kedro_project/pipelines/data_science/nodes.py", + "parameters": {"test_size": 0.2}, + "inputs": ["params:input1", "input2"], + "outputs": ["output1"], + "run_command": "kedro run --to-nodes=split_data", + } + } + ) + + +class DataNodeMetadataAPIResponse(BaseAPIResponse): + """ + DataNodeMetadataAPIResponse is a class that represents the metadata response for a data node in the Kedro visualization API. + + Attributes: + filepath (Optional[str]): The file path of the data node. + type (str): The type of the data node. + run_command (Optional[str]): The command to run the data node. + preview (Optional[Union[Dict, str]]): A preview of the data node's content. + preview_type (Optional[str]): The type of the preview. + stats (Optional[Dict]): Statistics related to the data node. + """ + + filepath: Optional[str] = None + type: str + run_command: Optional[str] = None + preview: Optional[Union[Dict, str]] = None + preview_type: Optional[str] = None + stats: Optional[Dict] = None + model_config = ConfigDict( + json_schema_extra={ + "example": { + "filepath": "/my-kedro-project/data/03_primary/master_table.csv", + "type": "kedro_datasets.pandas.csv_dataset.CSVDataset", + "run_command": "kedro run --to-outputs=master_table", + } + } + ) + + +class TranscodedDataNodeMetadataAPIReponse(BaseAPIResponse): + """ + TranscodedDataNodeMetadataAPIReponse represents the metadata response for a transcoded data node. + + Attributes: + filepath (Optional[str]): The file path of the transcoded data node. + original_type (str): The original type of the data node. + transcoded_types (List[str]): A list of types to which the data node has been transcoded. + run_command (Optional[str]): The command used to run the transcoding process. + stats (Optional[Dict]): Statistics related to the transcoded data node. + """ + + filepath: Optional[str] = None + original_type: str + transcoded_types: List[str] + run_command: Optional[str] = None + stats: Optional[Dict] = None + + +class ParametersNodeMetadataAPIResponse(BaseAPIResponse): + """ + ParametersNodeMetadataAPIResponse is a subclass of BaseAPIResponse that represents the metadata response for parameters nodes. + + Attributes: + parameters (Dict): A dictionary containing the parameters. + """ + + parameters: Dict + model_config = ConfigDict( + json_schema_extra={ + "example": { + "parameters": { + "test_size": 0.2, + "random_state": 3, + "features": [ + "engines", + "passenger_capacity", + "crew", + "d_check_complete", + "moon_clearance_complete", + "iata_approved", + "company_rating", + "review_scores_rating", + ], + } + } + } + ) + + +NodeMetadataAPIResponse = Union[ + TaskNodeMetadataAPIResponse, + DataNodeMetadataAPIResponse, + TranscodedDataNodeMetadataAPIReponse, + ParametersNodeMetadataAPIResponse, +] + + +def get_node_metadata_response(node_id: str): + """API response for `/api/nodes/node_id`.""" + node = data_access_manager.nodes.get_node_by_id(node_id) + if not node: + return JSONResponse(status_code=404, content={"message": "Invalid node ID"}) + + if not node.has_metadata(): + return JSONResponse(content={}) + + if isinstance(node, TaskNode): + return TaskNodeMetadata(task_node=node) + + if isinstance(node, DataNode): + return DataNodeMetadata(data_node=node) + + if isinstance(node, TranscodedDataNode): + return TranscodedDataNodeMetadata(transcoded_data_node=node) + + return ParametersNodeMetadata(parameters_node=node) diff --git a/package/kedro_viz/api/rest/responses/pipelines.py b/package/kedro_viz/api/rest/responses/pipelines.py new file mode 100644 index 0000000000..c5c096b8e5 --- /dev/null +++ b/package/kedro_viz/api/rest/responses/pipelines.py @@ -0,0 +1,256 @@ +"""`kedro_viz.api.rest.responses.pipelines` contains response classes +and utility functions for the `/main` and `/pipelines/* REST endpoints""" + +import json +import logging +from typing import Dict, List, Optional, Union + +from fastapi.responses import JSONResponse +from pydantic import ConfigDict + +from kedro_viz.api.rest.responses.base import BaseAPIResponse +from kedro_viz.api.rest.responses.utils import get_encoded_response +from kedro_viz.data_access import data_access_manager + +logger = logging.getLogger(__name__) + + +class BaseGraphNodeAPIResponse(BaseAPIResponse): + """ + BaseGraphNodeAPIResponse is a data model for representing the response of a graph node in the API. + + Attributes: + id (str): The unique identifier of the graph node. + name (str): The name of the graph node. + tags (List[str]): A list of tags associated with the graph node. + pipelines (List[str]): A list of pipelines that the graph node belongs to. + type (str): The type of the graph node. + modular_pipelines (Optional[List[str]]): A list of modular pipelines associated with the graph node. + This value will be None if the node is a ModularPipeline node. + """ + + id: str + name: str + tags: List[str] + pipelines: List[str] + type: str + + # If a node is a ModularPipeline node, this value will be None, hence Optional. + modular_pipelines: Optional[List[str]] = None + + +class TaskNodeAPIResponse(BaseGraphNodeAPIResponse): + """ + TaskNodeAPIResponse is a subclass of BaseGraphNodeAPIResponse that represents the response for a task node in the API. + + Attributes: + parameters (Dict): A dictionary containing the parameters for the task node. + """ + + parameters: Dict + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": "6ab908b8", + "name": "split_data_node", + "tags": [], + "pipelines": ["__default__", "ds"], + "modular_pipelines": [], + "type": "task", + "parameters": { + "test_size": 0.2, + "random_state": 3, + "features": [ + "engines", + "passenger_capacity", + "crew", + "d_check_complete", + "moon_clearance_complete", + "iata_approved", + "company_rating", + "review_scores_rating", + ], + }, + } + } + ) + + +class DataNodeAPIResponse(BaseGraphNodeAPIResponse): + """ + DataNodeAPIResponse is a subclass of BaseGraphNodeAPIResponse that represents the response model for a data node in the API. + + Attributes: + layer (Optional[str]): The layer to which the data node belongs. Default is None. + dataset_type (Optional[str]): The type of dataset. Default is None. + stats (Optional[Dict]): Statistics related to the dataset, such as number of rows, columns, and file size. Default is None. + """ + + layer: Optional[str] = None + dataset_type: Optional[str] = None + stats: Optional[Dict] = None + model_config = ConfigDict( + json_schema_extra={ + "example": { + "id": "d7b83b05", + "name": "master_table", + "tags": [], + "pipelines": ["__default__", "dp", "ds"], + "modular_pipelines": [], + "type": "data", + "layer": "primary", + "dataset_type": "kedro_datasets.pandas.csv_dataset.CSVDataset", + "stats": {"rows": 10, "columns": 2, "file_size": 2300}, + } + } + ) + + +NodeAPIResponse = Union[ + TaskNodeAPIResponse, + DataNodeAPIResponse, +] + + +class GraphEdgeAPIResponse(BaseAPIResponse): + """ + GraphEdgeAPIResponse represents the response model for an edge in the graph. + + Attributes: + source (str): The source node id for the edge. + target (str): The target node id for the edge. + """ + + source: str + target: str + + +class NamedEntityAPIResponse(BaseAPIResponse): + """Model an API field that has an ID and a name. + For example, used for representing modular pipelines and pipelines in the API response. + """ + + id: str + name: Optional[str] = None + + +class ModularPipelineChildAPIResponse(BaseAPIResponse): + """Model a child in a modular pipeline's children field in the API response.""" + + id: str + type: str + + +class ModularPipelinesTreeNodeAPIResponse(BaseAPIResponse): + """Model a node in the tree representation of modular pipelines in the API response.""" + + id: str + name: str + inputs: List[str] + outputs: List[str] + children: List[ModularPipelineChildAPIResponse] + + +# Represent the modular pipelines in the API response as a tree. +# The root node is always designated with the __root__ key. +# Example: +# { +# "__root__": { +# "id": "__root__", +# "name": "Root", +# "inputs": [], +# "outputs": [], +# "children": [ +# {"id": "d577578a", "type": "parameters"}, +# {"id": "data_science", "type": "modularPipeline"}, +# {"id": "f1f1425b", "type": "parameters"}, +# {"id": "data_engineering", "type": "modularPipeline"}, +# ], +# }, +# "data_engineering": { +# "id": "data_engineering", +# "name": "Data Engineering", +# "inputs": ["d577578a"], +# "outputs": [], +# "children": [], +# }, +# "data_science": { +# "id": "data_science", +# "name": "Data Science", +# "inputs": ["f1f1425b"], +# "outputs": [], +# "children": [], +# }, +# } +# } +ModularPipelinesTreeAPIResponse = Dict[str, ModularPipelinesTreeNodeAPIResponse] + + +class GraphAPIResponse(BaseAPIResponse): + """ + GraphAPIResponse is a data model for the response of the graph API. + + Attributes: + nodes (List[NodeAPIResponse]): A list of nodes in the graph. + edges (List[GraphEdgeAPIResponse]): A list of edges connecting the nodes in the graph. + layers (List[str]): A list of layers in the graph. + tags (List[NamedEntityAPIResponse]): A list of tags associated with the graph entities. + pipelines (List[NamedEntityAPIResponse]): A list of pipelines in the graph. + modular_pipelines (ModularPipelinesTreeAPIResponse): A tree structure representing modular pipelines. + selected_pipeline (str): The identifier of the selected pipeline. + """ + + nodes: List[NodeAPIResponse] + edges: List[GraphEdgeAPIResponse] + layers: List[str] + tags: List[NamedEntityAPIResponse] + pipelines: List[NamedEntityAPIResponse] + modular_pipelines: ModularPipelinesTreeAPIResponse + selected_pipeline: str + + +def get_pipeline_response( + pipeline_id: Union[str, None] = None, +) -> Union[GraphAPIResponse, JSONResponse]: + """API response for `/api/pipelines/pipeline_id`.""" + if pipeline_id is None: + pipeline_id = data_access_manager.get_default_selected_pipeline().id + + if not data_access_manager.registered_pipelines.has_pipeline(pipeline_id): + return JSONResponse(status_code=404, content={"message": "Invalid pipeline ID"}) + + modular_pipelines_tree = ( + data_access_manager.create_modular_pipelines_tree_for_registered_pipeline( + pipeline_id + ) + ) + + return GraphAPIResponse( + nodes=data_access_manager.get_nodes_for_registered_pipeline(pipeline_id), + edges=data_access_manager.get_edges_for_registered_pipeline(pipeline_id), + tags=data_access_manager.tags.as_list(), + layers=data_access_manager.get_sorted_layers_for_registered_pipeline( + pipeline_id + ), + pipelines=data_access_manager.registered_pipelines.as_list(), + modular_pipelines=modular_pipelines_tree, + selected_pipeline=pipeline_id, + ) + + +def get_kedro_project_json_data(): + """Decodes the default response and returns the Kedro project JSON data. + This will be used in VSCode extension to get current Kedro project data.""" + encoded_response = get_encoded_response(get_pipeline_response()) + + try: + response_str = encoded_response.decode("utf-8") + json_data = json.loads(response_str) + except UnicodeDecodeError as exc: # pragma: no cover + json_data = None + logger.error("Failed to decode response string. Error: %s", str(exc)) + except json.JSONDecodeError as exc: # pragma: no cover + json_data = None + logger.error("Failed to parse JSON data. Error: %s", str(exc)) + + return json_data diff --git a/package/kedro_viz/api/rest/responses/save_responses.py b/package/kedro_viz/api/rest/responses/save_responses.py new file mode 100644 index 0000000000..bcdd335534 --- /dev/null +++ b/package/kedro_viz/api/rest/responses/save_responses.py @@ -0,0 +1,97 @@ +"""`kedro_viz.api.rest.responses.save_responses` contains response classes +and utility functions for writing and saving REST endpoint responses to file system""" + +import logging +from typing import Any + +from kedro_viz.api.rest.responses.nodes import get_node_metadata_response +from kedro_viz.api.rest.responses.pipelines import get_pipeline_response +from kedro_viz.api.rest.responses.utils import get_encoded_response +from kedro_viz.data_access import data_access_manager +from kedro_viz.models.flowchart.node_metadata import DataNodeMetadata + +logger = logging.getLogger(__name__) + + +def save_api_responses_to_fs(path: str, remote_fs: Any, is_all_previews_enabled: bool): + """Saves all Kedro Viz API responses to a directory.""" + try: + logger.debug( + """Saving/Uploading api files to %s""", + path, + ) + + main_path = f"{path}/api/main" + nodes_path = f"{path}/api/nodes" + pipelines_path = f"{path}/api/pipelines" + + if "file" in remote_fs.protocol: + remote_fs.makedirs(path, exist_ok=True) + remote_fs.makedirs(nodes_path, exist_ok=True) + remote_fs.makedirs(pipelines_path, exist_ok=True) + + save_api_main_response_to_fs(main_path, remote_fs) + save_api_node_response_to_fs(nodes_path, remote_fs, is_all_previews_enabled) + save_api_pipeline_response_to_fs(pipelines_path, remote_fs) + + except Exception as exc: # pragma: no cover + logger.exception( + "An error occurred while preparing data for saving. Error: %s", str(exc) + ) + raise exc + + +def save_api_main_response_to_fs(main_path: str, remote_fs: Any): + """Saves API /main response to a directory.""" + try: + write_api_response_to_fs(main_path, get_pipeline_response(), remote_fs) + except Exception as exc: # pragma: no cover + logger.exception("Failed to save default response. Error: %s", str(exc)) + raise exc + + +def save_api_pipeline_response_to_fs(pipelines_path: str, remote_fs: Any): + """Saves API /pipelines/{pipeline} response to a directory.""" + for pipeline_id in data_access_manager.registered_pipelines.get_pipeline_ids(): + try: + write_api_response_to_fs( + f"{pipelines_path}/{pipeline_id}", + get_pipeline_response(pipeline_id), + remote_fs, + ) + except Exception as exc: # pragma: no cover + logger.exception( + "Failed to save pipeline data for pipeline ID %s. Error: %s", + pipeline_id, + str(exc), + ) + raise exc + + +def save_api_node_response_to_fs( + nodes_path: str, remote_fs: Any, is_all_previews_enabled: bool +): + """Saves API /nodes/{node} response to a directory.""" + # Set if preview is enabled/disabled for all data nodes + DataNodeMetadata.set_is_all_previews_enabled(is_all_previews_enabled) + + for node_id in data_access_manager.nodes.get_node_ids(): + try: + write_api_response_to_fs( + f"{nodes_path}/{node_id}", + get_node_metadata_response(node_id), + remote_fs, + ) + except Exception as exc: # pragma: no cover + logger.exception( + "Failed to save node data for node ID %s. Error: %s", node_id, str(exc) + ) + raise exc + + +def write_api_response_to_fs(file_path: str, response: Any, remote_fs: Any): + """Get encoded responses and writes it to a file""" + encoded_response = get_encoded_response(response) + + with remote_fs.open(file_path, "wb") as file: + file.write(encoded_response) diff --git a/package/kedro_viz/api/rest/responses/utils.py b/package/kedro_viz/api/rest/responses/utils.py new file mode 100644 index 0000000000..38bae09460 --- /dev/null +++ b/package/kedro_viz/api/rest/responses/utils.py @@ -0,0 +1,44 @@ +"""`kedro_viz.api.rest.responses.utils` contains utility +response classes and functions for the REST endpoints""" + +import logging +from typing import Any + +import orjson +from fastapi.encoders import jsonable_encoder +from fastapi.responses import ORJSONResponse + +logger = logging.getLogger(__name__) + + +class EnhancedORJSONResponse(ORJSONResponse): + """ + EnhancedORJSONResponse is a subclass of ORJSONResponse that provides + additional functionality for encoding content to a human-readable JSON format. + """ + + @staticmethod + def encode_to_human_readable(content: Any) -> bytes: + """A method to encode the given content to JSON, with the + proper formatting to write a human-readable file. + + Returns: + A bytes object containing the JSON to write. + + """ + return orjson.dumps( + content, + option=orjson.OPT_INDENT_2 + | orjson.OPT_NON_STR_KEYS + | orjson.OPT_SERIALIZE_NUMPY, + ) + + +def get_encoded_response(response: Any) -> bytes: + """Encodes and enhances the default response using human-readable format.""" + jsonable_response = jsonable_encoder(response) + encoded_response = EnhancedORJSONResponse.encode_to_human_readable( + jsonable_response + ) + + return encoded_response diff --git a/package/kedro_viz/api/rest/router.py b/package/kedro_viz/api/rest/router.py index a32e204281..2a743239fb 100644 --- a/package/kedro_viz/api/rest/router.py +++ b/package/kedro_viz/api/rest/router.py @@ -6,35 +6,31 @@ from fastapi.responses import JSONResponse from kedro_viz.api.rest.requests import DeployerConfiguration -from kedro_viz.integrations.deployment.deployer_factory import DeployerFactory - -from .responses import ( - APIErrorMessage, - GraphAPIResponse, +from kedro_viz.api.rest.responses.base import APINotFoundResponse +from kedro_viz.api.rest.responses.metadata import ( MetadataAPIResponse, - NodeMetadataAPIResponse, - get_default_response, get_metadata_response, +) +from kedro_viz.api.rest.responses.nodes import ( + NodeMetadataAPIResponse, get_node_metadata_response, - get_selected_pipeline_response, ) - -try: - from azure.core.exceptions import ServiceRequestError -except ImportError: # pragma: no cover - ServiceRequestError = None # type: ignore +from kedro_viz.api.rest.responses.pipelines import ( + GraphAPIResponse, + get_pipeline_response, +) logger = logging.getLogger(__name__) router = APIRouter( prefix="/api", - responses={404: {"model": APIErrorMessage}}, + responses={404: {"model": APINotFoundResponse}}, ) @router.get("/main", response_model=GraphAPIResponse) async def main(): - return get_default_response() + return get_pipeline_response() @router.get( @@ -51,11 +47,18 @@ async def get_single_node_metadata(node_id: str): response_model=GraphAPIResponse, ) async def get_single_pipeline_data(registered_pipeline_id: str): - return get_selected_pipeline_response(registered_pipeline_id) + return get_pipeline_response(registered_pipeline_id) @router.post("/deploy") async def deploy_kedro_viz(input_values: DeployerConfiguration): + from kedro_viz.integrations.deployment.deployer_factory import DeployerFactory + + try: + from azure.core.exceptions import ServiceRequestError + except ImportError: # pragma: no cover + ServiceRequestError = None # type: ignore + try: deployer = DeployerFactory.create_deployer( input_values.platform, input_values.endpoint, input_values.bucket_name diff --git a/package/kedro_viz/data_access/managers.py b/package/kedro_viz/data_access/managers.py index 4468804c77..f7e572a497 100644 --- a/package/kedro_viz/data_access/managers.py +++ b/package/kedro_viz/data_access/managers.py @@ -4,7 +4,6 @@ from collections import defaultdict from typing import Dict, List, Set, Union -import networkx as nx from kedro.io import DataCatalog try: @@ -549,6 +548,8 @@ def create_modular_pipelines_tree_for_registered_pipeline( # noqa: PLR0912 # so no need to check non modular pipeline nodes. # # We leverage networkx to help with graph traversal + import networkx as nx + digraph = nx.DiGraph() for edge in edges: digraph.add_edge(edge.source, edge.target) diff --git a/package/kedro_viz/integrations/deployment/base_deployer.py b/package/kedro_viz/integrations/deployment/base_deployer.py index 35b7fc1818..d0f0b2a7bf 100644 --- a/package/kedro_viz/integrations/deployment/base_deployer.py +++ b/package/kedro_viz/integrations/deployment/base_deployer.py @@ -12,7 +12,7 @@ from packaging.version import parse from kedro_viz import __version__ -from kedro_viz.api.rest.responses import save_api_responses_to_fs +from kedro_viz.api.rest.responses.save_responses import save_api_responses_to_fs from kedro_viz.integrations.kedro import telemetry as kedro_telemetry _HTML_DIR = Path(__file__).parent.parent.parent.absolute() / "html" diff --git a/package/kedro_viz/launchers/cli/deploy.py b/package/kedro_viz/launchers/cli/deploy.py index 75d0b8bb43..87e9157033 100644 --- a/package/kedro_viz/launchers/cli/deploy.py +++ b/package/kedro_viz/launchers/cli/deploy.py @@ -5,6 +5,7 @@ from kedro_viz.constants import SHAREABLEVIZ_SUPPORTED_PLATFORMS from kedro_viz.launchers.cli.main import viz +from kedro_viz.launchers.utils import display_cli_message @viz.command(context_settings={"help_option_names": ["-h", "--help"]}) @@ -39,10 +40,7 @@ ) def deploy(platform, endpoint, bucket_name, include_hooks, include_previews): """Deploy and host Kedro Viz on provided platform""" - from kedro_viz.launchers.cli.utils import ( - create_shareableviz_process, - display_cli_message, - ) + from kedro_viz.launchers.cli.utils import create_shareableviz_process if not platform or platform.lower() not in SHAREABLEVIZ_SUPPORTED_PLATFORMS: display_cli_message( diff --git a/package/kedro_viz/launchers/cli/run.py b/package/kedro_viz/launchers/cli/run.py index 4fab6c1869..e7dd08b408 100644 --- a/package/kedro_viz/launchers/cli/run.py +++ b/package/kedro_viz/launchers/cli/run.py @@ -111,13 +111,13 @@ def run( get_latest_version, is_running_outdated_version, ) - from kedro_viz.launchers.cli.utils import display_cli_message from kedro_viz.launchers.utils import ( _PYPROJECT, _check_viz_up, _find_kedro_project, _start_browser, _wait_for, + display_cli_message, ) from kedro_viz.server import run_server diff --git a/package/kedro_viz/launchers/cli/utils.py b/package/kedro_viz/launchers/cli/utils.py index b5a376022b..60e7403535 100644 --- a/package/kedro_viz/launchers/cli/utils.py +++ b/package/kedro_viz/launchers/cli/utils.py @@ -4,9 +4,8 @@ from time import sleep from typing import Union -import click - from kedro_viz.constants import VIZ_DEPLOY_TIME_LIMIT +from kedro_viz.launchers.utils import display_cli_message def create_shareableviz_process( @@ -103,16 +102,6 @@ def create_shareableviz_process( viz_deploy_process.terminate() -def display_cli_message(msg, msg_color=None): - """Displays message for Kedro Viz build and deploy commands""" - click.echo( - click.style( - msg, - fg=msg_color, - ) - ) - - def _load_and_deploy_viz( platform, is_all_previews_enabled, diff --git a/package/kedro_viz/launchers/utils.py b/package/kedro_viz/launchers/utils.py index 00fcde64eb..5c6bbae9e3 100644 --- a/package/kedro_viz/launchers/utils.py +++ b/package/kedro_viz/launchers/utils.py @@ -7,6 +7,7 @@ from time import sleep, time from typing import Any, Callable, Union +import click import requests logger = logging.getLogger(__name__) @@ -113,3 +114,13 @@ def _find_kedro_project(current_dir: Path) -> Any: if _is_project(project_dir): return project_dir return None + + +def display_cli_message(msg, msg_color=None): + """Displays message for Kedro Viz build and deploy commands""" + click.echo( + click.style( + msg, + fg=msg_color, + ) + ) diff --git a/package/kedro_viz/server.py b/package/kedro_viz/server.py index d9b8fbc2e6..251bb32b6b 100644 --- a/package/kedro_viz/server.py +++ b/package/kedro_viz/server.py @@ -8,13 +8,12 @@ from kedro.io import DataCatalog from kedro.pipeline import Pipeline -from kedro_viz.api.rest.responses import save_api_responses_to_fs from kedro_viz.constants import DEFAULT_HOST, DEFAULT_PORT from kedro_viz.data_access import DataAccessManager, data_access_manager from kedro_viz.database import make_db_session_factory from kedro_viz.integrations.kedro import data_loader as kedro_data_loader from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore -from kedro_viz.launchers.utils import _check_viz_up, _wait_for +from kedro_viz.launchers.utils import _check_viz_up, _wait_for, display_cli_message DEV_PORT = 4142 @@ -124,6 +123,10 @@ def run_server( # [TODO: As we can do this with `kedro viz build`, # we need to shift this feature outside of kedro viz run] if save_file: + from kedro_viz.api.rest.responses.save_responses import ( + save_api_responses_to_fs, + ) + save_api_responses_to_fs(save_file, fsspec.filesystem("file"), True) app = apps.create_api_app_from_project(path, autoreload) @@ -170,13 +173,14 @@ def run_server( target=run_process, daemon=False, kwargs={**run_process_kwargs} ) - print("Starting Kedro Viz ...") + display_cli_message("Starting Kedro Viz ...", "green") viz_process.start() _wait_for(func=_check_viz_up, host=args.host, port=args.port) - print( + display_cli_message( "Kedro Viz started successfully. \n\n" - f"\u2728 Kedro Viz is running at \n http://{args.host}:{args.port}/" + f"\u2728 Kedro Viz is running at \n http://{args.host}:{args.port}/", + "green", ) diff --git a/package/tests/conftest.py b/package/tests/conftest.py index c6b802974a..5c1a300abb 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -485,7 +485,12 @@ def example_api( example_stats_dict, ) mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager + "kedro_viz.api.rest.responses.pipelines.data_access_manager", + new=data_access_manager, + ) + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, ) yield api @@ -504,7 +509,12 @@ def example_api_no_default_pipeline( data_access_manager, example_catalog, example_pipelines, session_store, {} ) mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager + "kedro_viz.api.rest.responses.pipelines.data_access_manager", + new=data_access_manager, + ) + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, ) yield api @@ -534,7 +544,12 @@ def example_api_for_edge_case_pipelines( {}, ) mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager + "kedro_viz.api.rest.responses.pipelines.data_access_manager", + new=data_access_manager, + ) + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, ) yield api @@ -556,7 +571,12 @@ def example_transcoded_api( {}, ) mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager", new=data_access_manager + "kedro_viz.api.rest.responses.pipelines.data_access_manager", + new=data_access_manager, + ) + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, ) yield api diff --git a/package/tests/test_api/test_rest/test_responses/__init__.py b/package/tests/test_api/test_rest/test_responses/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/package/tests/test_api/test_rest/test_responses.py b/package/tests/test_api/test_rest/test_responses/assert_helpers.py similarity index 50% rename from package/tests/test_api/test_rest/test_responses.py rename to package/tests/test_api/test_rest/test_responses/assert_helpers.py index 8dbf549416..a55ecd9b81 100644 --- a/package/tests/test_api/test_rest/test_responses.py +++ b/package/tests/test_api/test_rest/test_responses/assert_helpers.py @@ -1,26 +1,5 @@ -import json import operator -from pathlib import Path from typing import Any, Dict, Iterable, List -from unittest import mock -from unittest.mock import Mock, call, patch - -import pytest -from fastapi.testclient import TestClient - -from kedro_viz.api import apps -from kedro_viz.api.rest.responses import ( - EnhancedORJSONResponse, - get_kedro_project_json_data, - get_metadata_response, - save_api_main_response_to_fs, - save_api_node_response_to_fs, - save_api_pipeline_response_to_fs, - save_api_responses_to_fs, - write_api_response_to_fs, -) -from kedro_viz.models.flowchart.nodes import TaskNode -from kedro_viz.models.metadata import Metadata def _is_dict_list(collection: Any) -> bool: @@ -29,19 +8,21 @@ def _is_dict_list(collection: Any) -> bool: return False -def assert_dict_list_equal( - response: List[Dict], expected: List[Dict], sort_keys: Iterable[str] -): - """Assert two list of dictionaries with undeterministic order - to be equal by sorting them first based on a sort key. - """ - if len(response) == 0: - assert len(expected) == 0 - return +def assert_modular_pipelines_tree_equal(response: Dict, expected: Dict): + """Assert if modular pipelines tree are equal.""" + # first assert that they have the same set of keys + assert sorted(response.keys()) == sorted(expected.keys()) - assert sorted(response, key=operator.itemgetter(*sort_keys)) == sorted( - expected, key=operator.itemgetter(*sort_keys) - ) + # then compare the dictionary at each key recursively + for key in response: + if isinstance(response[key], dict): + assert_modular_pipelines_tree_equal(response[key], expected[key]) + elif _is_dict_list(response[key]): + assert_dict_list_equal(response[key], expected[key], sort_keys=("id",)) + elif isinstance(response[key], list): + assert sorted(response[key]) == sorted(expected[key]) + else: + assert response[key] == expected[key] def assert_nodes_equal(response_nodes, expected_nodes): @@ -70,21 +51,19 @@ def assert_nodes_equal(response_nodes, expected_nodes): assert response_node == expected_node -def assert_modular_pipelines_tree_equal(response: Dict, expected: Dict): - """Assert if modular pipelines tree are equal.""" - # first assert that they have the same set of keys - assert sorted(response.keys()) == sorted(expected.keys()) +def assert_dict_list_equal( + response: List[Dict], expected: List[Dict], sort_keys: Iterable[str] +): + """Assert two list of dictionaries with undeterministic order + to be equal by sorting them first based on a sort key. + """ + if len(response) == 0: + assert len(expected) == 0 + return - # then compare the dictionary at each key recursively - for key in response: - if isinstance(response[key], dict): - assert_modular_pipelines_tree_equal(response[key], expected[key]) - elif _is_dict_list(response[key]): - assert_dict_list_equal(response[key], expected[key], sort_keys=("id",)) - elif isinstance(response[key], list): - assert sorted(response[key]) == sorted(expected[key]) - else: - assert response[key] == expected[key] + assert sorted(response, key=operator.itemgetter(*sort_keys)) == sorted( + expected, key=operator.itemgetter(*sort_keys) + ) def assert_example_data(response_data): @@ -563,500 +542,3 @@ def assert_example_transcoded_data(response_data): ] assert_nodes_equal(response_data.pop("nodes"), expected_nodes) - - -class TestMainEndpoint: - """Test a viz API created from a Kedro project.""" - - def test_endpoint_main(self, client): - response = client.get("/api/main") - assert_example_data(response.json()) - - def test_endpoint_main_no_default_pipeline(self, example_api_no_default_pipeline): - client = TestClient(example_api_no_default_pipeline) - response = client.get("/api/main") - assert len(response.json()["nodes"]) == 6 - assert len(response.json()["edges"]) == 9 - assert response.json()["pipelines"] == [ - {"id": "data_science", "name": "data_science"}, - {"id": "data_processing", "name": "data_processing"}, - ] - - def test_endpoint_main_for_edge_case_pipelines( - self, - example_api_for_edge_case_pipelines, - expected_modular_pipeline_tree_for_edge_cases, - ): - client = TestClient(example_api_for_edge_case_pipelines) - response = client.get("/api/main") - actual_modular_pipelines_tree = response.json()["modular_pipelines"] - assert_modular_pipelines_tree_equal( - actual_modular_pipelines_tree, expected_modular_pipeline_tree_for_edge_cases - ) - - -class TestTranscodedDataset: - """Test a viz API created from a Kedro project.""" - - def test_endpoint_main(self, example_transcoded_api): - client = TestClient(example_transcoded_api) - response = client.get("/api/main") - assert response.status_code == 200 - assert_example_transcoded_data(response.json()) - - def test_transcoded_data_node_metadata(self, example_transcoded_api): - client = TestClient(example_transcoded_api) - response = client.get("/api/nodes/0ecea0de") - assert response.json() == { - "filepath": "model_inputs.csv", - "original_type": "pandas.csv_dataset.CSVDataset", - "transcoded_types": [ - "pandas.parquet_dataset.ParquetDataset", - ], - "run_command": "kedro run --to-outputs=model_inputs@pandas2", - } - - -class TestNodeMetadataEndpoint: - def test_node_not_exist(self, client): - response = client.get("/api/nodes/foo") - assert response.status_code == 404 - - def test_task_node_metadata(self, client): - response = client.get("/api/nodes/782e4a43") - metadata = response.json() - assert ( - metadata["code"].replace(" ", "") - == "defprocess_data(raw_data,train_test_split):\npass\n" - ) - assert metadata["parameters"] == {"uk.data_processing.train_test_split": 0.1} - assert metadata["inputs"] == [ - "uk.data_processing.raw_data", - "params:uk.data_processing.train_test_split", - ] - assert metadata["outputs"] == ["model_inputs"] - assert ( - metadata["run_command"] - == "kedro run --to-nodes='uk.data_processing.process_data'" - ) - assert str(Path("package/tests/conftest.py")) in metadata["filepath"] - - def test_data_node_metadata(self, client): - response = client.get("/api/nodes/0ecea0de") - assert response.json() == { - "filepath": "model_inputs.csv", - "type": "pandas.csv_dataset.CSVDataset", - "preview_type": "TablePreview", - "run_command": "kedro run --to-outputs=model_inputs", - "stats": {"columns": 12, "rows": 29768}, - } - - def test_data_node_metadata_for_free_input(self, client): - response = client.get("/api/nodes/13399a82") - assert response.json() == { - "filepath": "raw_data.csv", - "preview_type": "TablePreview", - "type": "pandas.csv_dataset.CSVDataset", - } - - def test_parameters_node_metadata(self, client): - response = client.get("/api/nodes/f1f1425b") - assert response.json() == { - "parameters": {"train_test_split": 0.1, "num_epochs": 1000} - } - - def test_single_parameter_node_metadata(self, client): - response = client.get("/api/nodes/f0ebef01") - assert response.json() == { - "parameters": {"uk.data_processing.train_test_split": 0.1} - } - - def test_no_metadata(self, client): - with mock.patch.object(TaskNode, "has_metadata", return_value=False): - response = client.get("/api/nodes/782e4a43") - assert response.json() == {} - - -class TestSinglePipelineEndpoint: - def test_get_pipeline(self, client): - response = client.get("/api/pipelines/data_science") - assert response.status_code == 200 - response_data = response.json() - expected_edges = [ - {"source": "f2b25286", "target": "d5a8b994"}, - {"source": "f1f1425b", "target": "uk.data_science"}, - {"source": "f1f1425b", "target": "f2b25286"}, - {"source": "uk.data_science", "target": "d5a8b994"}, - {"source": "uk", "target": "d5a8b994"}, - {"source": "0ecea0de", "target": "uk"}, - {"source": "0ecea0de", "target": "uk.data_science"}, - {"source": "f1f1425b", "target": "uk"}, - {"source": "0ecea0de", "target": "f2b25286"}, - ] - assert_dict_list_equal( - response_data.pop("edges"), expected_edges, sort_keys=("source", "target") - ) - expected_nodes = [ - { - "id": "0ecea0de", - "name": "model_inputs", - "tags": ["train", "split"], - "pipelines": ["__default__", "data_science", "data_processing"], - "modular_pipelines": ["uk.data_science", "uk.data_processing"], - "type": "data", - "layer": "model_inputs", - "dataset_type": "pandas.csv_dataset.CSVDataset", - "stats": {"columns": 12, "rows": 29768}, - }, - { - "id": "f2b25286", - "name": "train_model", - "tags": ["train"], - "pipelines": ["__default__", "data_science"], - "modular_pipelines": ["uk.data_science"], - "type": "task", - "parameters": { - "train_test_split": 0.1, - "num_epochs": 1000, - }, - }, - { - "id": "f1f1425b", - "name": "parameters", - "tags": ["train"], - "pipelines": ["__default__", "data_science"], - "modular_pipelines": None, - "type": "parameters", - "layer": None, - "dataset_type": None, - "stats": None, - }, - { - "id": "d5a8b994", - "name": "uk.data_science.model", - "tags": ["train"], - "pipelines": ["__default__", "data_science"], - "modular_pipelines": ["uk", "uk.data_science"], - "type": "data", - "layer": None, - "dataset_type": "io.memory_dataset.MemoryDataset", - "stats": None, - }, - { - "id": "uk", - "name": "uk", - "tags": ["train"], - "pipelines": ["data_science"], - "type": "modularPipeline", - "modular_pipelines": None, - "layer": None, - "dataset_type": None, - "stats": None, - }, - { - "id": "uk.data_science", - "name": "uk.data_science", - "tags": ["train"], - "pipelines": ["data_science"], - "type": "modularPipeline", - "modular_pipelines": None, - "layer": None, - "dataset_type": None, - "stats": None, - }, - ] - assert_nodes_equal(response_data.pop("nodes"), expected_nodes) - - expected_modular_pipelines = { - "__root__": { - "children": [ - {"id": "f1f1425b", "type": "parameters"}, - {"id": "0ecea0de", "type": "data"}, - {"id": "uk", "type": "modularPipeline"}, - {"id": "d5a8b994", "type": "data"}, - ], - "id": "__root__", - "inputs": [], - "name": "__root__", - "outputs": [], - }, - "uk": { - "children": [ - {"id": "uk.data_science", "type": "modularPipeline"}, - ], - "id": "uk", - "inputs": ["0ecea0de", "f1f1425b"], - "name": "uk", - "outputs": ["d5a8b994"], - }, - "uk.data_science": { - "children": [ - {"id": "f2b25286", "type": "task"}, - ], - "id": "uk.data_science", - "inputs": ["0ecea0de", "f1f1425b"], - "name": "uk.data_science", - "outputs": ["d5a8b994"], - }, - } - - assert_modular_pipelines_tree_equal( - response_data.pop("modular_pipelines"), - expected_modular_pipelines, - ) - - # Extract and sort the layers field - response_data_layers_sorted = sorted(response_data["layers"]) - expected_layers_sorted = sorted(["model_inputs", "raw"]) - assert response_data_layers_sorted == expected_layers_sorted - - # Remove the layers field from response_data for further comparison - response_data.pop("layers") - - # Expected response without the layers field - expected_response_without_layers = { - "tags": [ - {"id": "split", "name": "split"}, - {"id": "train", "name": "train"}, - ], - "pipelines": [ - {"id": "__default__", "name": "__default__"}, - {"id": "data_science", "name": "data_science"}, - {"id": "data_processing", "name": "data_processing"}, - ], - "selected_pipeline": "data_science", - } - assert response_data == expected_response_without_layers - - def test_get_non_existing_pipeline(self, client): - response = client.get("/api/pipelines/foo") - assert response.status_code == 404 - - -class TestAppMetadata: - def test_get_metadata_response(self, mocker): - mock_get_compat = mocker.patch( - "kedro_viz.api.rest.responses.get_package_compatibilities", - return_value="mocked_compatibilities", - ) - mock_set_compat = mocker.patch( - "kedro_viz.api.rest.responses.Metadata.set_package_compatibilities" - ) - - response = get_metadata_response() - - # Assert get_package_compatibilities was called - mock_get_compat.assert_called_once() - - # Assert set_package_compatibilities was called with the mocked compatibilities - mock_set_compat.assert_called_once_with("mocked_compatibilities") - - # Assert the function returns the Metadata instance - assert isinstance(response, Metadata) - - -class TestAPIAppFromFile: - def test_api_app_from_json_file_main_api(self): - filepath = str(Path(__file__).parent.parent) - api_app = apps.create_api_app_from_file(filepath) - client = TestClient(api_app) - response = client.get("/api/main") - assert_example_data_from_file(response.json()) - - def test_api_app_from_json_file_index(self): - filepath = str(Path(__file__).parent.parent) - api_app = apps.create_api_app_from_file(filepath) - client = TestClient(api_app) - response = client.get("/") - assert response.status_code == 200 - - -class TestEnhancedORJSONResponse: - @pytest.mark.parametrize( - "content, expected", - [ - ( - {"key1": "value1", "key2": "value2"}, - b'{\n "key1": "value1",\n "key2": "value2"\n}', - ), - (["item1", "item2"], b'[\n "item1",\n "item2"\n]'), - ], - ) - def test_encode_to_human_readable(self, content, expected): - result = EnhancedORJSONResponse.encode_to_human_readable(content) - assert result == expected - - @pytest.mark.parametrize( - "file_path, response, encoded_response", - [ - ( - "test_output.json", - {"key1": "value1", "key2": "value2"}, - b'{"key1": "value1", "key2": "value2"}', - ), - ], - ) - def test_write_api_response_to_fs( - self, file_path, response, encoded_response, mocker - ): - mock_encode_to_human_readable = mocker.patch( - "kedro_viz.api.rest.responses.EnhancedORJSONResponse.encode_to_human_readable", - return_value=encoded_response, - ) - with patch("fsspec.filesystem") as mock_filesystem: - mockremote_fs = mock_filesystem.return_value - mockremote_fs.open.return_value.__enter__.return_value = Mock() - write_api_response_to_fs(file_path, response, mockremote_fs) - mockremote_fs.open.assert_called_once_with(file_path, "wb") - mock_encode_to_human_readable.assert_called_once() - - def test_get_kedro_project_json_data(self, mocker): - expected_json_data = {"key": "value"} - encoded_response = json.dumps(expected_json_data).encode("utf-8") - - mock_get_default_response = mocker.patch( - "kedro_viz.api.rest.responses.get_default_response", - return_value={"key": "value"}, - ) - mock_get_encoded_response = mocker.patch( - "kedro_viz.api.rest.responses.get_encoded_response", - return_value=encoded_response, - ) - - json_data = get_kedro_project_json_data() - - mock_get_default_response.assert_called_once() - mock_get_encoded_response.assert_called_once_with( - mock_get_default_response.return_value - ) - assert json_data == expected_json_data - - def test_save_api_main_response_to_fs(self, mocker): - expected_default_response = {"test": "json"} - main_path = "/main" - - mock_get_default_response = mocker.patch( - "kedro_viz.api.rest.responses.get_default_response", - return_value=expected_default_response, - ) - mock_write_api_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.write_api_response_to_fs" - ) - - remote_fs = Mock() - - save_api_main_response_to_fs(main_path, remote_fs) - - mock_get_default_response.assert_called_once() - mock_write_api_response_to_fs.assert_called_once_with( - main_path, mock_get_default_response.return_value, remote_fs - ) - - def test_save_api_node_response_to_fs(self, mocker): - nodes_path = "/nodes" - nodeIds = ["01f456", "01f457"] - expected_metadata_response = {"test": "json"} - - mock_get_node_metadata_response = mocker.patch( - "kedro_viz.api.rest.responses.get_node_metadata_response", - return_value=expected_metadata_response, - ) - mock_write_api_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.write_api_response_to_fs" - ) - mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager.nodes.get_node_ids", - return_value=nodeIds, - ) - remote_fs = Mock() - - save_api_node_response_to_fs(nodes_path, remote_fs, False) - - assert mock_write_api_response_to_fs.call_count == len(nodeIds) - assert mock_get_node_metadata_response.call_count == len(nodeIds) - - expected_calls = [ - call( - f"{nodes_path}/{nodeId}", - mock_get_node_metadata_response.return_value, - remote_fs, - ) - for nodeId in nodeIds - ] - mock_write_api_response_to_fs.assert_has_calls(expected_calls, any_order=True) - - def test_save_api_pipeline_response_to_fs(self, mocker): - pipelines_path = "/pipelines" - pipelineIds = ["01f456", "01f457"] - expected_selected_pipeline_response = {"test": "json"} - - mock_get_selected_pipeline_response = mocker.patch( - "kedro_viz.api.rest.responses.get_selected_pipeline_response", - return_value=expected_selected_pipeline_response, - ) - mock_write_api_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.write_api_response_to_fs" - ) - - mocker.patch( - "kedro_viz.api.rest.responses.data_access_manager." - "registered_pipelines.get_pipeline_ids", - return_value=pipelineIds, - ) - - remote_fs = Mock() - - save_api_pipeline_response_to_fs(pipelines_path, remote_fs) - - assert mock_write_api_response_to_fs.call_count == len(pipelineIds) - assert mock_get_selected_pipeline_response.call_count == len(pipelineIds) - - expected_calls = [ - call( - f"{pipelines_path}/{pipelineId}", - mock_get_selected_pipeline_response.return_value, - remote_fs, - ) - for pipelineId in pipelineIds - ] - mock_write_api_response_to_fs.assert_has_calls(expected_calls, any_order=True) - - @pytest.mark.parametrize( - "file_path, protocol, is_all_previews_enabled", - [ - ("s3://shareableviz", "s3", True), - ("abfs://shareableviz", "abfs", False), - ("shareableviz", "file", True), - ], - ) - def test_save_api_responses_to_fs( - self, file_path, protocol, is_all_previews_enabled, mocker - ): - mock_api_main_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.save_api_main_response_to_fs" - ) - mock_api_node_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.save_api_node_response_to_fs" - ) - mock_api_pipeline_response_to_fs = mocker.patch( - "kedro_viz.api.rest.responses.save_api_pipeline_response_to_fs" - ) - - mock_filesystem = mocker.patch("fsspec.filesystem") - mock_filesystem.return_value.protocol = protocol - - save_api_responses_to_fs( - file_path, mock_filesystem.return_value, is_all_previews_enabled - ) - - mock_api_main_response_to_fs.assert_called_once_with( - f"{file_path}/api/main", mock_filesystem.return_value - ) - mock_api_node_response_to_fs.assert_called_once_with( - f"{file_path}/api/nodes", - mock_filesystem.return_value, - is_all_previews_enabled, - ) - mock_api_pipeline_response_to_fs.assert_called_once_with( - f"{file_path}/api/pipelines", mock_filesystem.return_value - ) diff --git a/package/tests/test_api/test_rest/test_responses/test_base.py b/package/tests/test_api/test_rest/test_responses/test_base.py new file mode 100755 index 0000000000..d487fc542d --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_base.py @@ -0,0 +1,10 @@ +from kedro_viz.api.rest.responses.base import APINotFoundResponse + + +def test_api_not_found_response_valid_message(): + response = APINotFoundResponse(message="Resource not found") + assert response.message == "Resource not found" + + # Test that the model is serializable to a dictionary + serialized_response = response.model_dump() + assert serialized_response == {"message": "Resource not found"} diff --git a/package/tests/test_api/test_rest/test_responses/test_metadata.py b/package/tests/test_api/test_rest/test_responses/test_metadata.py new file mode 100755 index 0000000000..c6e8dd6d12 --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_metadata.py @@ -0,0 +1,24 @@ +from kedro_viz.api.rest.responses.metadata import get_metadata_response +from kedro_viz.models.metadata import Metadata + + +class TestAppMetadata: + def test_get_metadata_response(self, mocker): + mock_get_compat = mocker.patch( + "kedro_viz.api.rest.responses.metadata.get_package_compatibilities", + return_value="mocked_compatibilities", + ) + mock_set_compat = mocker.patch( + "kedro_viz.api.rest.responses.metadata.Metadata.set_package_compatibilities" + ) + + response = get_metadata_response() + + # Assert get_package_compatibilities was called + mock_get_compat.assert_called_once() + + # Assert set_package_compatibilities was called with the mocked compatibilities + mock_set_compat.assert_called_once_with("mocked_compatibilities") + + # Assert the function returns the Metadata instance + assert isinstance(response, Metadata) diff --git a/package/tests/test_api/test_rest/test_responses/test_nodes.py b/package/tests/test_api/test_rest/test_responses/test_nodes.py new file mode 100644 index 0000000000..6ee2008826 --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_nodes.py @@ -0,0 +1,91 @@ +from pathlib import Path +from unittest import mock + +from fastapi.testclient import TestClient + +from kedro_viz.models.flowchart.nodes import TaskNode +from tests.test_api.test_rest.test_responses.assert_helpers import ( + assert_example_transcoded_data, +) + + +class TestTranscodedDataset: + """Test a viz API created from a Kedro project.""" + + def test_endpoint_main(self, example_transcoded_api): + client = TestClient(example_transcoded_api) + response = client.get("/api/main") + assert response.status_code == 200 + assert_example_transcoded_data(response.json()) + + def test_transcoded_data_node_metadata(self, example_transcoded_api): + client = TestClient(example_transcoded_api) + response = client.get("/api/nodes/0ecea0de") + assert response.json() == { + "filepath": "model_inputs.csv", + "original_type": "pandas.csv_dataset.CSVDataset", + "transcoded_types": [ + "pandas.parquet_dataset.ParquetDataset", + ], + "run_command": "kedro run --to-outputs=model_inputs@pandas2", + } + + +class TestNodeMetadataEndpoint: + def test_node_not_exist(self, client): + response = client.get("/api/nodes/foo") + assert response.status_code == 404 + + def test_task_node_metadata(self, client): + response = client.get("/api/nodes/782e4a43") + metadata = response.json() + assert ( + metadata["code"].replace(" ", "") + == "defprocess_data(raw_data,train_test_split):\npass\n" + ) + assert metadata["parameters"] == {"uk.data_processing.train_test_split": 0.1} + assert metadata["inputs"] == [ + "uk.data_processing.raw_data", + "params:uk.data_processing.train_test_split", + ] + assert metadata["outputs"] == ["model_inputs"] + assert ( + metadata["run_command"] + == "kedro run --to-nodes='uk.data_processing.process_data'" + ) + assert str(Path("package/tests/conftest.py")) in metadata["filepath"] + + def test_data_node_metadata(self, client): + response = client.get("/api/nodes/0ecea0de") + assert response.json() == { + "filepath": "model_inputs.csv", + "type": "pandas.csv_dataset.CSVDataset", + "preview_type": "TablePreview", + "run_command": "kedro run --to-outputs=model_inputs", + "stats": {"columns": 12, "rows": 29768}, + } + + def test_data_node_metadata_for_free_input(self, client): + response = client.get("/api/nodes/13399a82") + assert response.json() == { + "filepath": "raw_data.csv", + "preview_type": "TablePreview", + "type": "pandas.csv_dataset.CSVDataset", + } + + def test_parameters_node_metadata(self, client): + response = client.get("/api/nodes/f1f1425b") + assert response.json() == { + "parameters": {"train_test_split": 0.1, "num_epochs": 1000} + } + + def test_single_parameter_node_metadata(self, client): + response = client.get("/api/nodes/f0ebef01") + assert response.json() == { + "parameters": {"uk.data_processing.train_test_split": 0.1} + } + + def test_no_metadata(self, client): + with mock.patch.object(TaskNode, "has_metadata", return_value=False): + response = client.get("/api/nodes/782e4a43") + assert response.json() == {} diff --git a/package/tests/test_api/test_rest/test_responses/test_pipelines.py b/package/tests/test_api/test_rest/test_responses/test_pipelines.py new file mode 100755 index 0000000000..4b933e33e2 --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_pipelines.py @@ -0,0 +1,241 @@ +import json +from pathlib import Path + +from fastapi.testclient import TestClient + +from kedro_viz.api import apps +from kedro_viz.api.rest.responses.pipelines import get_kedro_project_json_data +from tests.test_api.test_rest.test_responses.assert_helpers import ( + assert_dict_list_equal, + assert_example_data, + assert_example_data_from_file, + assert_modular_pipelines_tree_equal, + assert_nodes_equal, +) + + +class TestMainEndpoint: + """Test a viz API created from a Kedro project.""" + + def test_endpoint_main(self, client, mocker, data_access_manager): + mocker.patch( + "kedro_viz.api.rest.responses.nodes.data_access_manager", + new=data_access_manager, + ) + response = client.get("/api/main") + assert_example_data(response.json()) + + def test_endpoint_main_no_default_pipeline(self, example_api_no_default_pipeline): + client = TestClient(example_api_no_default_pipeline) + response = client.get("/api/main") + assert len(response.json()["nodes"]) == 6 + assert len(response.json()["edges"]) == 9 + assert response.json()["pipelines"] == [ + {"id": "data_science", "name": "data_science"}, + {"id": "data_processing", "name": "data_processing"}, + ] + + def test_endpoint_main_for_edge_case_pipelines( + self, + example_api_for_edge_case_pipelines, + expected_modular_pipeline_tree_for_edge_cases, + ): + client = TestClient(example_api_for_edge_case_pipelines) + response = client.get("/api/main") + actual_modular_pipelines_tree = response.json()["modular_pipelines"] + assert_modular_pipelines_tree_equal( + actual_modular_pipelines_tree, expected_modular_pipeline_tree_for_edge_cases + ) + + def test_get_kedro_project_json_data(self, mocker): + expected_json_data = {"key": "value"} + encoded_response = json.dumps(expected_json_data).encode("utf-8") + + mock_get_default_response = mocker.patch( + "kedro_viz.api.rest.responses.pipelines.get_pipeline_response", + return_value={"key": "value"}, + ) + mock_get_encoded_response = mocker.patch( + "kedro_viz.api.rest.responses.pipelines.get_encoded_response", + return_value=encoded_response, + ) + + json_data = get_kedro_project_json_data() + + mock_get_default_response.assert_called_once() + mock_get_encoded_response.assert_called_once_with( + mock_get_default_response.return_value + ) + assert json_data == expected_json_data + + +class TestSinglePipelineEndpoint: + def test_get_pipeline(self, client): + response = client.get("/api/pipelines/data_science") + assert response.status_code == 200 + response_data = response.json() + expected_edges = [ + {"source": "f2b25286", "target": "d5a8b994"}, + {"source": "f1f1425b", "target": "uk.data_science"}, + {"source": "f1f1425b", "target": "f2b25286"}, + {"source": "uk.data_science", "target": "d5a8b994"}, + {"source": "uk", "target": "d5a8b994"}, + {"source": "0ecea0de", "target": "uk"}, + {"source": "0ecea0de", "target": "uk.data_science"}, + {"source": "f1f1425b", "target": "uk"}, + {"source": "0ecea0de", "target": "f2b25286"}, + ] + assert_dict_list_equal( + response_data.pop("edges"), expected_edges, sort_keys=("source", "target") + ) + expected_nodes = [ + { + "id": "0ecea0de", + "name": "model_inputs", + "tags": ["train", "split"], + "pipelines": ["__default__", "data_science", "data_processing"], + "modular_pipelines": ["uk.data_science", "uk.data_processing"], + "type": "data", + "layer": "model_inputs", + "dataset_type": "pandas.csv_dataset.CSVDataset", + "stats": {"columns": 12, "rows": 29768}, + }, + { + "id": "f2b25286", + "name": "train_model", + "tags": ["train"], + "pipelines": ["__default__", "data_science"], + "modular_pipelines": ["uk.data_science"], + "type": "task", + "parameters": { + "train_test_split": 0.1, + "num_epochs": 1000, + }, + }, + { + "id": "f1f1425b", + "name": "parameters", + "tags": ["train"], + "pipelines": ["__default__", "data_science"], + "modular_pipelines": None, + "type": "parameters", + "layer": None, + "dataset_type": None, + "stats": None, + }, + { + "id": "d5a8b994", + "name": "uk.data_science.model", + "tags": ["train"], + "pipelines": ["__default__", "data_science"], + "modular_pipelines": ["uk", "uk.data_science"], + "type": "data", + "layer": None, + "dataset_type": "io.memory_dataset.MemoryDataset", + "stats": None, + }, + { + "id": "uk", + "name": "uk", + "tags": ["train"], + "pipelines": ["data_science"], + "type": "modularPipeline", + "modular_pipelines": None, + "layer": None, + "dataset_type": None, + "stats": None, + }, + { + "id": "uk.data_science", + "name": "uk.data_science", + "tags": ["train"], + "pipelines": ["data_science"], + "type": "modularPipeline", + "modular_pipelines": None, + "layer": None, + "dataset_type": None, + "stats": None, + }, + ] + assert_nodes_equal(response_data.pop("nodes"), expected_nodes) + + expected_modular_pipelines = { + "__root__": { + "children": [ + {"id": "f1f1425b", "type": "parameters"}, + {"id": "0ecea0de", "type": "data"}, + {"id": "uk", "type": "modularPipeline"}, + {"id": "d5a8b994", "type": "data"}, + ], + "id": "__root__", + "inputs": [], + "name": "__root__", + "outputs": [], + }, + "uk": { + "children": [ + {"id": "uk.data_science", "type": "modularPipeline"}, + ], + "id": "uk", + "inputs": ["0ecea0de", "f1f1425b"], + "name": "uk", + "outputs": ["d5a8b994"], + }, + "uk.data_science": { + "children": [ + {"id": "f2b25286", "type": "task"}, + ], + "id": "uk.data_science", + "inputs": ["0ecea0de", "f1f1425b"], + "name": "uk.data_science", + "outputs": ["d5a8b994"], + }, + } + + assert_modular_pipelines_tree_equal( + response_data.pop("modular_pipelines"), + expected_modular_pipelines, + ) + + # Extract and sort the layers field + response_data_layers_sorted = sorted(response_data["layers"]) + expected_layers_sorted = sorted(["model_inputs", "raw"]) + assert response_data_layers_sorted == expected_layers_sorted + + # Remove the layers field from response_data for further comparison + response_data.pop("layers") + + # Expected response without the layers field + expected_response_without_layers = { + "tags": [ + {"id": "split", "name": "split"}, + {"id": "train", "name": "train"}, + ], + "pipelines": [ + {"id": "__default__", "name": "__default__"}, + {"id": "data_science", "name": "data_science"}, + {"id": "data_processing", "name": "data_processing"}, + ], + "selected_pipeline": "data_science", + } + assert response_data == expected_response_without_layers + + def test_get_non_existing_pipeline(self, client): + response = client.get("/api/pipelines/foo") + assert response.status_code == 404 + + +class TestAPIAppFromFile: + def test_api_app_from_json_file_main_api(self): + filepath = str(Path(__file__).parent.parent.parent) + api_app = apps.create_api_app_from_file(filepath) + client = TestClient(api_app) + response = client.get("/api/main") + assert_example_data_from_file(response.json()) + + def test_api_app_from_json_file_index(self): + filepath = str(Path(__file__).parent.parent.parent) + api_app = apps.create_api_app_from_file(filepath) + client = TestClient(api_app) + response = client.get("/") + assert response.status_code == 200 diff --git a/package/tests/test_api/test_rest/test_responses/test_save_responses.py b/package/tests/test_api/test_rest/test_responses/test_save_responses.py new file mode 100644 index 0000000000..828fe26269 --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_save_responses.py @@ -0,0 +1,168 @@ +from unittest import mock +from unittest.mock import Mock, call, patch + +import pytest + +from kedro_viz.api.rest.responses.save_responses import ( + save_api_main_response_to_fs, + save_api_node_response_to_fs, + save_api_pipeline_response_to_fs, + save_api_responses_to_fs, + write_api_response_to_fs, +) + + +class TestSaveAPIResponse: + @pytest.mark.parametrize( + "file_path, protocol, is_all_previews_enabled", + [ + ("s3://shareableviz", "s3", True), + ("abfs://shareableviz", "abfs", False), + ("shareableviz", "file", True), + ], + ) + def test_save_api_responses_to_fs( + self, file_path, protocol, is_all_previews_enabled, mocker + ): + mock_api_main_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.save_api_main_response_to_fs" + ) + mock_api_node_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.save_api_node_response_to_fs" + ) + mock_api_pipeline_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.save_api_pipeline_response_to_fs" + ) + + mock_filesystem = mocker.patch("fsspec.filesystem") + mock_filesystem.return_value.protocol = protocol + + save_api_responses_to_fs( + file_path, mock_filesystem.return_value, is_all_previews_enabled + ) + + mock_api_main_response_to_fs.assert_called_once_with( + f"{file_path}/api/main", mock_filesystem.return_value + ) + mock_api_node_response_to_fs.assert_called_once_with( + f"{file_path}/api/nodes", + mock_filesystem.return_value, + is_all_previews_enabled, + ) + mock_api_pipeline_response_to_fs.assert_called_once_with( + f"{file_path}/api/pipelines", mock_filesystem.return_value + ) + + def test_save_api_main_response_to_fs(self, mocker): + expected_default_response = {"test": "json"} + main_path = "/main" + + mock_get_default_response = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.get_pipeline_response", + return_value=expected_default_response, + ) + mock_write_api_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.write_api_response_to_fs" + ) + + remote_fs = Mock() + + save_api_main_response_to_fs(main_path, remote_fs) + + mock_get_default_response.assert_called_once() + mock_write_api_response_to_fs.assert_called_once_with( + main_path, mock_get_default_response.return_value, remote_fs + ) + + def test_save_api_pipeline_response_to_fs(self, mocker): + pipelines_path = "/pipelines" + pipelineIds = ["01f456", "01f457"] + expected_selected_pipeline_response = {"test": "json"} + + mock_get_selected_pipeline_response = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.get_pipeline_response", + return_value=expected_selected_pipeline_response, + ) + mock_write_api_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.write_api_response_to_fs" + ) + + mocker.patch( + "kedro_viz.api.rest.responses.save_responses.data_access_manager." + "registered_pipelines.get_pipeline_ids", + return_value=pipelineIds, + ) + + remote_fs = Mock() + + save_api_pipeline_response_to_fs(pipelines_path, remote_fs) + + assert mock_write_api_response_to_fs.call_count == len(pipelineIds) + assert mock_get_selected_pipeline_response.call_count == len(pipelineIds) + + expected_calls = [ + call( + f"{pipelines_path}/{pipelineId}", + mock_get_selected_pipeline_response.return_value, + remote_fs, + ) + for pipelineId in pipelineIds + ] + mock_write_api_response_to_fs.assert_has_calls(expected_calls, any_order=True) + + def test_save_api_node_response_to_fs(self, mocker): + nodes_path = "/nodes" + nodeIds = ["01f456", "01f457"] + expected_metadata_response = {"test": "json"} + + mock_get_node_metadata_response = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.get_node_metadata_response", + return_value=expected_metadata_response, + ) + mock_write_api_response_to_fs = mocker.patch( + "kedro_viz.api.rest.responses.save_responses.write_api_response_to_fs" + ) + mocker.patch( + "kedro_viz.api.rest.responses.save_responses.data_access_manager.nodes.get_node_ids", + return_value=nodeIds, + ) + remote_fs = mock.Mock() + + save_api_node_response_to_fs(nodes_path, remote_fs, False) + + assert mock_write_api_response_to_fs.call_count == len(nodeIds) + assert mock_get_node_metadata_response.call_count == len(nodeIds) + + expected_calls = [ + mock.call( + f"{nodes_path}/{nodeId}", + mock_get_node_metadata_response.return_value, + remote_fs, + ) + for nodeId in nodeIds + ] + mock_write_api_response_to_fs.assert_has_calls(expected_calls, any_order=True) + + @pytest.mark.parametrize( + "file_path, response, encoded_response", + [ + ( + "test_output.json", + {"key1": "value1", "key2": "value2"}, + b'{"key1": "value1", "key2": "value2"}', + ), + ], + ) + def test_write_api_response_to_fs( + self, file_path, response, encoded_response, mocker + ): + mock_encode_to_human_readable = mocker.patch( + "kedro_viz.api.rest.responses.utils.EnhancedORJSONResponse.encode_to_human_readable", + return_value=encoded_response, + ) + with patch("fsspec.filesystem") as mock_filesystem: + mockremote_fs = mock_filesystem.return_value + mockremote_fs.open.return_value.__enter__.return_value = Mock() + write_api_response_to_fs(file_path, response, mockremote_fs) + mockremote_fs.open.assert_called_once_with(file_path, "wb") + mock_encode_to_human_readable.assert_called_once() diff --git a/package/tests/test_api/test_rest/test_responses/test_utils.py b/package/tests/test_api/test_rest/test_responses/test_utils.py new file mode 100644 index 0000000000..cad8607e2b --- /dev/null +++ b/package/tests/test_api/test_rest/test_responses/test_utils.py @@ -0,0 +1,43 @@ +import pytest + +from kedro_viz.api.rest.responses.utils import ( + EnhancedORJSONResponse, + get_encoded_response, +) + + +class TestEnhancedORJSONResponse: + @pytest.mark.parametrize( + "content, expected", + [ + ( + {"key1": "value1", "key2": "value2"}, + b'{\n "key1": "value1",\n "key2": "value2"\n}', + ), + (["item1", "item2"], b'[\n "item1",\n "item2"\n]'), + ], + ) + def test_encode_to_human_readable(self, content, expected): + result = EnhancedORJSONResponse.encode_to_human_readable(content) + assert result == expected + + +def test_get_encoded_response(mocker): + mock_jsonable_encoder = mocker.patch( + "kedro_viz.api.rest.responses.utils.jsonable_encoder" + ) + mock_encode_to_human_readable = mocker.patch( + "kedro_viz.api.rest.responses.utils.EnhancedORJSONResponse.encode_to_human_readable" + ) + + mock_response = {"key": "value"} + mock_jsonable_encoder.return_value = mock_response + mock_encoded_response = b"encoded-response" + mock_encode_to_human_readable.return_value = mock_encoded_response + + result = get_encoded_response(mock_response) + + # Assertions + mock_jsonable_encoder.assert_called_once_with(mock_response) + mock_encode_to_human_readable.assert_called_once_with(mock_response) + assert result == mock_encoded_response diff --git a/package/tests/test_api/test_rest/test_router.py b/package/tests/test_api/test_rest/test_router.py index d84f1ce0f2..523043d96d 100644 --- a/package/tests/test_api/test_rest/test_router.py +++ b/package/tests/test_api/test_rest/test_router.py @@ -21,7 +21,7 @@ def test_deploy_kedro_viz( client, platform, endpoint, bucket_name, is_all_previews_enabled, mocker ): mocker.patch( - "kedro_viz.api.rest.router.DeployerFactory.create_deployer", + "kedro_viz.integrations.deployment.deployer_factory.DeployerFactory.create_deployer", return_value=MockDeployer(platform, endpoint, bucket_name), ) response = client.post( diff --git a/package/tests/test_server.py b/package/tests/test_server.py index 33fe6f2e1b..2169e9d4da 100644 --- a/package/tests/test_server.py +++ b/package/tests/test_server.py @@ -151,7 +151,7 @@ def test_load_file( def test_save_file(self, tmp_path, mocker): mock_filesystem = mocker.patch("fsspec.filesystem") save_api_responses_to_fs_mock = mocker.patch( - "kedro_viz.server.save_api_responses_to_fs" + "kedro_viz.api.rest.responses.save_responses.save_api_responses_to_fs" ) save_file = tmp_path / "save.json" run_server(save_file=save_file)