Skip to content

Commit

Permalink
Refactor datasets (#1698)
Browse files Browse the repository at this point in the history
Description
Related to: #1700

Kedro-Viz and Kedro-Datasets are tightly coupled, so modifications in Kedro-Datasets can break functionality in Kedro-Viz. This ticket aims to refactor the codebase to enhance modularity and stability. The goal is to decouple the two components, allowing for smoother integration in the future without excessive interdependencies.

Development notes
We've moved the logic for generating previews for several datasets from kedro-viz to kedro-datasets, reducing the coupling between kedro-viz and kedro-datasets.

Each kind of preview is rendered differently in the front-end, so we've introduced type aliases using NewType so that the kind of preview can be identified. Currently, we support four types of previews in the front-end: json, image [png], plotly, and dataframe [as tables]. Users can enable previews for custom datasets as long as they fall into one of these categories.

Previously, we overrode the datasets' load function in kedro-viz to load dataset previews. This logic has now been migrated to kedro-datasets, into the preview() function. In the kedro-viz API, we no longer include plot, image, or tracking_data as fields of DataNodeMetadata; instead, we now send only preview and preview_type. The preview field holds the preview of any dataset that implements a preview() function, and preview_type (the name of that function's return annotation, e.g. TablePreview) tells the front-end how to render the preview.
  • Loading branch information
rashidakanchwala authored Feb 21, 2024
1 parent feba71c commit 5ca1550
Show file tree
Hide file tree
Showing 21 changed files with 190 additions and 433 deletions.
2 changes: 1 addition & 1 deletion demo-project/src/demo_project/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter~=1.0
jupyter_client>=5.1, <7.0
jupyterlab~=3.0
kedro~=0.18.0
kedro-datasets[pandas.CSVDataset,pandas.ExcelDataset, pandas.ParquetDataset, plotly.PlotlyDataset]<=2.0.0
git+https://github.com/kedro-org/kedro-plugins.git@main#egg=kedro-datasets[pandas.ParquetDataset,pandas.CSVDataset,pandas.ExcelDataset,plotly.JSONDataset]&subdirectory=kedro-datasets # temporary pin until the next release of kedro-datasets
nbstripout~=0.4
pytest-cov~=2.5
pytest-mock>=1.7.1, <2.0
Expand Down
2 changes: 1 addition & 1 deletion demo-project/src/docker_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
kedro>=0.18.0
kedro-datasets[pandas.CSVDataset,pandas.ExcelDataset, pandas.ParquetDataset, plotly.PlotlyDataset, matplotlib.MatplotlibWriter]<=2.0.0
git+https://github.com/kedro-org/kedro-plugins.git@main#egg=kedro-datasets[pandas.ParquetDataset,pandas.CSVDataset,pandas.ExcelDataset,plotly.JSONDataset]&subdirectory=kedro-datasets # temporary pin until the next release of kedro-datasets
scikit-learn~=1.0
pillow~=9.0
seaborn~=0.11.2
6 changes: 2 additions & 4 deletions package/kedro_viz/api/rest/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,9 @@ class Config:
class DataNodeMetadataAPIResponse(BaseAPIResponse):
filepath: Optional[str]
type: str
plot: Optional[Dict]
image: Optional[str]
tracking_data: Optional[Dict]
run_command: Optional[str]
preview: Optional[Dict]
preview: Optional[Union[Dict, str]]
preview_type: Optional[str]
stats: Optional[Dict]

class Config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
TRACKING_DATASET_GROUPS,
TrackingDatasetGroup,
TrackingDatasetModel,
get_dataset_type,
)
from kedro_viz.models.utils import get_dataset_type

if TYPE_CHECKING:
try:
Expand Down
66 changes: 0 additions & 66 deletions package/kedro_viz/integrations/kedro/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
load data from projects created in a range of Kedro versions.
"""
# pylint: disable=import-outside-toplevel, protected-access
# pylint: disable=missing-function-docstring

import base64
import json
import logging
from pathlib import Path
Expand All @@ -14,24 +12,7 @@
from kedro import __version__
from kedro.framework.session import KedroSession
from kedro.framework.session.store import BaseSessionStore

try:
from kedro_datasets import ( # isort:skip
json as json_dataset,
matplotlib,
plotly,
tracking,
)
except ImportError: # kedro_datasets is not installed.
from kedro.extras.datasets import ( # Safe since ImportErrors are suppressed within kedro.
json as json_dataset,
matplotlib,
plotly,
tracking,
)

from kedro.io import DataCatalog
from kedro.io.core import get_filepath_str
from kedro.pipeline import Pipeline

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -129,50 +110,3 @@ def load_data(
stats_dict = _get_dataset_stats(project_path)

return catalog, pipelines_dict, session_store, stats_dict


# Try to access the attribute to trigger the import of dependencies, only modify the _load
# if dependencies are installed.
# These datasets do not have _load methods defined (tracking and matplotlib) or do not
# load to json (plotly), hence the need to define _load here.
try:
getattr(matplotlib, "MatplotlibWriter") # Trigger the lazy import

def matplotlib_writer_load(dataset: matplotlib.MatplotlibWriter) -> str:
load_path = get_filepath_str(dataset._get_load_path(), dataset._protocol)
with dataset._fs.open(load_path, mode="rb") as img_file:
base64_bytes = base64.b64encode(img_file.read())
return base64_bytes.decode("utf-8")

matplotlib.MatplotlibWriter._load = matplotlib_writer_load
except (ImportError, AttributeError):
pass

try:
getattr(plotly, "JSONDataset") # Trigger import
plotly.JSONDataset._load = json_dataset.JSONDataset._load
except (ImportError, AttributeError):
getattr(plotly, "JSONDataSet") # Trigger import
plotly.JSONDataSet._load = json_dataset.JSONDataSet._load


try:
getattr(plotly, "PlotlyDataset") # Trigger import
plotly.PlotlyDataset._load = json_dataset.JSONDataset._load
except (ImportError, AttributeError):
getattr(plotly, "PlotlyDataSet") # Trigger import
plotly.PlotlyDataSet._load = json_dataset.JSONDataSet._load

try:
getattr(tracking, "JSONDataset") # Trigger import
tracking.JSONDataset._load = json_dataset.JSONDataset._load
except (ImportError, AttributeError):
getattr(tracking, "JSONDataSet") # Trigger import
tracking.JSONDataSet._load = json_dataset.JSONDataSet._load

try:
getattr(tracking, "MetricsDataset") # Trigger import
tracking.MetricsDataset._load = json_dataset.JSONDataset._load
except (ImportError, AttributeError):
getattr(tracking, "MetricsDataSet") # Trigger import
tracking.MetricsDataSet._load = json_dataset.JSONDataSet._load
8 changes: 5 additions & 3 deletions package/kedro_viz/models/experiment_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class TrackingDatasetGroup(str, Enum):
JSON = "json"


# Map dataset types (as produced by get_dataset_type) to their group
# Map dataset types to their group
TRACKING_DATASET_GROUPS = {
"plotly.plotly_dataset.PlotlyDataset": TrackingDatasetGroup.PLOT,
"plotly.json_dataset.JSONDataset": TrackingDatasetGroup.PLOT,
Expand Down Expand Up @@ -110,9 +110,11 @@ def load_tracking_data(self, run_id: str):

try:
if TRACKING_DATASET_GROUPS[self.dataset_type] is TrackingDatasetGroup.PLOT:
self.runs[run_id] = {self.dataset._filepath.name: self.dataset.load()}
self.runs[run_id] = {
self.dataset._filepath.name: self.dataset.preview() # type: ignore
}
else:
self.runs[run_id] = self.dataset.load()
self.runs[run_id] = self.dataset.preview() # type: ignore
except Exception as exc: # pylint: disable=broad-except # pragma: no cover
logger.warning(
"'%s' with version '%s' could not be loaded. Full exception: %s: %s",
Expand Down
114 changes: 37 additions & 77 deletions package/kedro_viz/models/flowchart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""`kedro_viz.models.flowchart` defines data models to represent Kedro entities in a viz graph."""

# pylint: disable=protected-access, missing-function-docstring
import abc
import hashlib
Expand Down Expand Up @@ -586,52 +587,16 @@ def set_viz_metadata(cls, _, values):

return None

# TODO: improve this scheme.
def is_plot_node(self):
"""Check if the current node is a plot node.
Currently it only recognises one underlying dataset as a plot node.
In the future, we might want to make this generic.
"""
return self.dataset_type in (
"plotly.plotly_dataset.PlotlyDataset",
"plotly.json_dataset.JSONDataset",
"plotly.plotly_dataset.PlotlyDataSet",
"plotly.json_dataset.JSONDataSet",
)

def is_image_node(self):
"""Check if the current node is a matplotlib image node."""
return self.dataset_type == "matplotlib.matplotlib_writer.MatplotlibWriter"

def is_metric_node(self):
"""Check if the current node is a metrics node."""
return self.dataset_type in (
"tracking.metrics_dataset.MetricsDataset",
"tracking.metrics_dataset.MetricsDataSet",
)

def is_json_node(self):
"""Check if the current node is a JSONDataset node."""
return self.dataset_type in (
"tracking.json_dataset.JSONDataset",
"tracking.json_dataset.JSONDataSet",
)

def is_tracking_node(self):
"""Checks if the current node is a tracking data node"""
return self.is_json_node() or self.is_metric_node()

def is_preview_node(self):
"""Checks if the current node has a preview"""
if not (self.viz_metadata and self.viz_metadata.get("preview_args", None)):
return False

return True

def get_preview_args(self):
"""Gets the preview arguments for a dataset"""
return self.viz_metadata.get("preview_args", None)

def is_preview_disabled(self):
"""Checks if the dataset has a preview disabled"""
return (
self.viz_metadata is not None and self.viz_metadata.get("preview") is False
)


class TranscodedDataNode(GraphNode):
"""Represent a graph node of type data
Expand Down Expand Up @@ -718,24 +683,15 @@ class DataNodeMetadata(GraphNodeMetadata):
# The path to the actual data file for the underlying dataset
filepath: Optional[str]

plot: Optional[Dict] = Field(
None, description="The optional plot data if the underlying dataset has a plot"
)

# The image data if the underlying dataset has a image
# currently only applicable for matplotlib.MatplotlibWriter
image: Optional[str] = Field(
None, description="The image data if the underlying dataset has a image"
)
tracking_data: Optional[Dict] = Field(
None,
description="The tracking data if the underlying dataset has a tracking dataset",
)
run_command: Optional[str] = Field(
None, description="Command to run the pipeline to this node"
)
preview: Optional[Dict] = Field(
None, description="Preview data for the underlying datanode"
preview: Optional[Union[Dict, str]] = Field(
None, description="Preview data for the underlying data node"
)

preview_type: Optional[str] = Field(
None, description="Type of preview for the dataset"
)
stats: Optional[Dict] = Field(None, description="The statistics for the data node.")

Expand Down Expand Up @@ -769,35 +725,39 @@ def set_run_command(cls, _):
return f"kedro run --to-outputs={cls.data_node.name}"
return None

@validator("plot", always=True)
def set_plot(cls, _):
if cls.data_node.is_plot_node():
return cls.data_node.kedro_obj.load()
return None
@validator("preview", always=True)
def set_preview(cls, _):
if cls.data_node.is_preview_disabled() or not hasattr(cls.dataset, "preview"):
return None

@validator("image", always=True)
def set_image(cls, _):
if cls.data_node.is_image_node():
return cls.data_node.kedro_obj.load()
return None
try:
preview_args = (
cls.data_node.get_preview_args() if cls.data_node.viz_metadata else None
)
if preview_args is None:
return cls.dataset.preview()
return cls.dataset.preview(**preview_args)

@validator("tracking_data", always=True)
def set_tracking_data(cls, _):
if cls.data_node.is_tracking_node():
return cls.data_node.kedro_obj.load()
return None
except Exception as exc: # pylint: disable=broad-except
logger.warning(
"'%s' could not be previewed. Full exception: %s: %s",
cls.data_node.name,
type(exc).__name__,
exc,
)
return None

@validator("preview", always=True)
def set_preview(cls, _):
if not (cls.data_node.is_preview_node() and hasattr(cls.dataset, "_preview")):
@validator("preview_type", always=True)
def set_preview_type(cls, _):
if cls.data_node.is_preview_disabled() or not hasattr(cls.dataset, "preview"):
return None

try:
return cls.dataset._preview(**cls.data_node.get_preview_args())
return inspect.signature(cls.dataset.preview).return_annotation.__name__

except Exception as exc: # pylint: disable=broad-except # pragma: no cover
logger.warning(
"'%s' could not be previewed. Full exception: %s: %s",
"'%s' did not have preview type. Full exception: %s: %s",
cls.data_node.name,
type(exc).__name__,
exc,
Expand Down
2 changes: 1 addition & 1 deletion package/test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-r requirements.txt

kedro >=0.18.0
kedro-datasets[pandas.ParquetDataset, pandas.CSVDataset, pandas.ExcelDataset, plotly.JSONDataset]<=2.0.0
git+https://github.com/kedro-org/kedro-plugins.git@main#egg=kedro-datasets[pandas.ParquetDataset,pandas.CSVDataset,pandas.ExcelDataset,plotly.JSONDataset]&subdirectory=kedro-datasets # temporary pin until the next release of kedro-datasets
kedro-telemetry>=0.1.1 # for testing telemetry integration
bandit~=1.7
behave~=1.2
Expand Down
42 changes: 35 additions & 7 deletions package/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,17 @@ def json(self):
@pytest.fixture
def example_data_frame():
data = {
"id": ["35029", "30292"],
"company_rating": ["100%", "67%"],
"company_location": ["Niue", "Anguilla"],
"total_fleet_count": ["4.0", "6.0"],
"iata_approved": ["f", "f"],
"id": ["35029", "30292", "12345", "67890", "54321", "98765", "11111"],
"company_rating": ["100%", "67%", "80%", "95%", "72%", "88%", "75%"],
"company_location": [
"Niue",
"Anguilla",
"Barbados",
"Fiji",
"Grenada",
"Jamaica",
"Trinidad and Tobago",
],
}
yield pd.DataFrame(data)

Expand All @@ -292,10 +298,32 @@ def example_csv_dataset(tmp_path, example_data_frame):


@pytest.fixture
def example_data_node():
def example_csv_filepath(tmp_path, example_data_frame):
csv_file_path = tmp_path / "temporary_test_data.csv"
example_data_frame.to_csv(csv_file_path, index=False)
yield csv_file_path


@pytest.fixture
def example_data_node(example_csv_filepath):
dataset_name = "uk.data_science.model_training.dataset"
metadata = {"kedro-viz": {"preview_args": {"nrows": 3}}}
kedro_dataset = CSVDataset(filepath="test.csv", metadata=metadata)
kedro_dataset = CSVDataset(filepath=example_csv_filepath, metadata=metadata)
data_node = GraphNode.create_data_node(
dataset_name=dataset_name,
layer="raw",
tags=set(),
dataset=kedro_dataset,
stats={"rows": 10, "columns": 5, "file_size": 1024},
)

yield data_node


@pytest.fixture
def example_data_node_without_viz_metadata(example_csv_filepath):
dataset_name = "uk.data_science.model_training.dataset"
kedro_dataset = CSVDataset(filepath=example_csv_filepath)
data_node = GraphNode.create_data_node(
dataset_name=dataset_name,
layer="raw",
Expand Down
6 changes: 5 additions & 1 deletion package/tests/test_api/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ class TestNodeMetadataEndpoint:
(
"13399a82",
200,
{"filepath": "raw_data.csv", "type": "pandas.csv_dataset.CSVDataset"},
{
"filepath": "raw_data.csv",
"preview_type": "TablePreview",
"type": "pandas.csv_dataset.CSVDataset",
},
),
],
)
Expand Down
Loading

0 comments on commit 5ca1550

Please sign in to comment.