Skip to content

Commit

Permalink
- check that it works with CSVDataSet object
Browse files Browse the repository at this point in the history
- check that it works if a run id is specified
  • Loading branch information
Galileo-Galilei committed Apr 13, 2020
1 parent 1c2a9cf commit 82ea397
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
Empty file added tests/__init__.py
Empty file.
Empty file added tests/io/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions tests/io/test_mlflow_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import pandas as pd
import mlflow
from kedro_mlflow.io import MlflowDataSet
from kedro.extras.datasets.pandas import CSVDataSet
import glob

@pytest.fixture
def tracking_uri(tmp_path):
return (tmp_path / "mlruns")


@pytest.fixture
def dummy_df1():
return pd.DataFrame({"col1": [1, 2, 3],
"col2": [4, 5, 6]})


@pytest.fixture
def dummy_df2():
return pd.DataFrame({"col3": [7, 8, 9],
"col4": ["a", "b", "c"]})


def test_mlflow_csv_data_set_save(tmp_path, tracking_uri, dummy_df1):
mlflow.set_tracking_uri(tracking_uri.as_uri())
mlflow_csv_dataset = MlflowDataSet(data_set=dict(type=CSVDataSet,
filepath=(tmp_path / "df1.csv").as_posix()))
with mlflow.start_run():
mlflow_csv_dataset.save(dummy_df1)
run_id = mlflow.active_run().info.run_id

# the artifact must be properly uploaded to "mlruns" and reloadable
assert (tracking_uri / "0" / run_id / "artifacts" / "df1.csv").is_file()
assert dummy_df1.equals(mlflow_csv_dataset.load())

def test_mlflow_data_set_save_with_run_id(tmp_path, tracking_uri, dummy_df1):
mlflow.set_tracking_uri(tracking_uri.as_uri())

# create a first run and get its id
with mlflow.start_run():
mlflow.log_param("fake", 2)
run_id = mlflow.active_run().info.run_id

# then same scneario but precise the run_id where data is saved
mlflow_csv_dataset = MlflowDataSet(data_set=dict(type=CSVDataSet,
filepath=(tmp_path / "df1.csv").as_posix()),
run_id=run_id)
mlflow_csv_dataset.save(dummy_df1)
tracked_runs = [f.stem for f in (tracking_uri / "0").glob("*") if f.is_dir()]
# same tests as previously, bu no new experiments must have been created
assert len(tracked_runs)==1
assert (tracking_uri / "0" / run_id / "artifacts" / "df1.csv").is_file()
assert dummy_df1.equals(mlflow_csv_dataset.load())

# use pytest.fixture.parametrize to test if it works with
# - type="pandas.CSVDataSet" (a string) -> because it is necessary for loading from catalog
# - experiment name = None or a string

0 comments on commit 82ea397

Please sign in to comment.