Skip to content

Commit

Permalink
✨ Add test for modern and legacy datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed Sep 9, 2024
1 parent 4c9af58 commit a33e2a3
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions tests/io/artifacts/test_mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mlflow
import pandas as pd
import pytest
from kedro.io import AbstractDataset
from kedro_datasets.pandas import CSVDataset
from kedro_datasets.partitions import PartitionedDataset
from kedro_datasets.pickle import PickleDataset
Expand Down Expand Up @@ -289,3 +290,83 @@ def test_partitioned_dataset_save_and_reload(
reloaded_data = {k: loader() for k, loader in mlflow_dataset.load().items()}
for k, df in data.items():
pd.testing.assert_frame_equal(df, reloaded_data[k])


def test_modern_dataset(tmp_path, mlflow_client, df1):
class MyOwnDatasetWithoutUnderscoreMethods(AbstractDataset):
def __init__(self, filepath):
self._filepath = Path(filepath)

def load(self) -> pd.DataFrame:
return pd.read_csv(self._filepath)

def save(self, df: pd.DataFrame) -> None:
df.to_csv(str(self._filepath), index=False)

def _exists(self) -> bool:
return Path(self._filepath.as_posix()).exists()

def _describe(self):
return dict(param1=self._filepath)

filepath = tmp_path / "data.csv"

mlflow_dataset = MlflowArtifactDataset(
artifact_path="artifact_dir",
dataset=dict(
type=MyOwnDatasetWithoutUnderscoreMethods, filepath=filepath.as_posix()
),
)

with mlflow.start_run():
mlflow_dataset.save(df1)
run_id = mlflow.active_run().info.run_id

# the artifact must be properly uploaded to "mlruns" and reloadable
run_artifacts = [
fileinfo.path
for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
]
remote_path = (Path("artifact_dir") / filepath.name).as_posix()
assert remote_path in run_artifacts
assert df1.equals(mlflow_dataset.load())


def test_legacy_dataset(tmp_path, mlflow_client, df1):
class MyOwnDatasetWithUnderscoreMethods(AbstractDataset):
def __init__(self, filepath):
self._filepath = Path(filepath)

def _load(self) -> pd.DataFrame:
return pd.read_csv(self._filepath)

def _save(self, df: pd.DataFrame) -> None:
df.to_csv(str(self._filepath), index=False)

def _exists(self) -> bool:
return Path(self._filepath.as_posix()).exists()

def _describe(self):
return dict(param1=self._filepath)

filepath = tmp_path / "data.csv"

mlflow_dataset = MlflowArtifactDataset(
artifact_path="artifact_dir",
dataset=dict(
type=MyOwnDatasetWithUnderscoreMethods, filepath=filepath.as_posix()
),
)

with mlflow.start_run():
mlflow_dataset.save(df1)
run_id = mlflow.active_run().info.run_id

# the artifact must be properly uploaded to "mlruns" and reloadable
run_artifacts = [
fileinfo.path
for fileinfo in mlflow_client.list_artifacts(run_id=run_id, path="artifact_dir")
]
remote_path = (Path("artifact_dir") / filepath.name).as_posix()
assert remote_path in run_artifacts
assert df1.equals(mlflow_dataset.load())

0 comments on commit a33e2a3

Please sign in to comment.