Skip to content

Commit

Permalink
🚑 Ensure MlflowArtifactDataset logs in the same run when using mlflow…
Browse files Browse the repository at this point in the history
…>=2.18 and ThreadRunner
  • Loading branch information
Galileo-Galilei committed Dec 14, 2024
1 parent ca7e729 commit b41dddb
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Fixed

- :bug: :ambulance: Ensure `MlflowArtifactDataset` logs in the same run as the parameters when using ``mlflow>=2.18`` in combination with ``ThreadRunner`` ([#613](https://github.com/Galileo-Galilei/kedro-mlflow/issues/613))

## [0.13.3] - 2024-10-29

### Added
Expand Down
10 changes: 10 additions & 0 deletions kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
_flatten_dict,
_generate_kedro_command,
)
from kedro_mlflow.io.catalog.add_run_id_to_artifact_datasets import (
add_run_id_to_artifact_datasets,
)
from kedro_mlflow.io.catalog.switch_catalog_logging import switch_catalog_logging
from kedro_mlflow.io.metrics import (
MlflowMetricDataset,
Expand Down Expand Up @@ -270,6 +273,13 @@ def before_pipeline_run(
pipeline_name=run_params["pipeline_name"],
),
)

# This call ensures that the run_id started at the beginning of the pipeline
# is associated with all the artifact datasets. This is necessary because, to make
# mlflow thread safe, each call to the "active run" now creates a new run when
# started in a new thread. See
# https://github.com/Galileo-Galilei/kedro-mlflow/issues/613 and https://github.com/Galileo-Galilei/kedro-mlflow/pull/615
add_run_id_to_artifact_datasets(catalog, mlflow.active_run().info.run_id)

else:
self._logger.info(
"kedro-mlflow logging is deactivated for this pipeline in the configuration. This includes DataSets and parameters."
Expand Down
6 changes: 6 additions & 0 deletions kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,9 @@ def _describe(self) -> Dict[str, Any]: # pragma: no cover
and consequently does not implement the abstract methods
"""
pass


def _is_instance_mlflow_artifact_dataset(dataset_instance):
parent_classname = dataset_instance.__class__.__bases__[0].__name__
instance_classname = f"Mlflow{parent_classname}"
return type(dataset_instance).__name__ == instance_classname
9 changes: 9 additions & 0 deletions kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from kedro_mlflow.io.artifacts.mlflow_artifact_dataset import (
_is_instance_mlflow_artifact_dataset,
)


def add_run_id_to_artifact_datasets(catalog, run_id: str):
    """Pin every Mlflow artifact dataset of ``catalog`` to ``run_id``.

    Needed because mlflow>=2.18 creates a new run per thread, so artifact
    datasets must be explicitly bound to the run opened before the pipeline
    started (see kedro-mlflow issue #613).
    """
    # NOTE(review): relies on the private ``_datasets`` mapping of DataCatalog.
    for dataset in catalog._datasets.values():
        if _is_instance_mlflow_artifact_dataset(dataset):
            dataset.run_id = run_id
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def mlflow_client(tracking_uri):

@pytest.fixture(autouse=True)
def cleanup_mlflow_after_runs():
    """Close any mlflow run left open by a test so state does not leak."""
    yield  # the test function runs at this point
    # mlflow.active_run() returns None once every nested run is closed
    while mlflow.active_run() is not None:
        mlflow.end_run()

Expand Down
159 changes: 159 additions & 0 deletions tests/framework/hooks/test_hook_log_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import mlflow
import pandas as pd
import pytest
from kedro.framework.session import KedroSession
from kedro.framework.startup import bootstrap_project
from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import Pipeline, node
from kedro.runner import ThreadRunner
from kedro_datasets.pickle import PickleDataset

from kedro_mlflow.framework.hooks.mlflow_hook import MlflowHook
from kedro_mlflow.io.artifacts import MlflowArtifactDataset


@pytest.fixture
def dummy_pipeline():
    """Minimal two-node pipeline: raw_data -> data -> model."""

    def preprocess_fun(data):
        # identity transform, just enough to exercise dataset wiring
        return data

    def train_fun(data):
        # constant "model" so the logged artifact content is deterministic
        return 2

    return Pipeline(
        [
            node(func=preprocess_fun, inputs="raw_data", outputs="data"),
            node(func=train_fun, inputs=["data"], outputs="model"),
        ]
    )


@pytest.fixture
def dummy_catalog(tmp_path):
    """Catalog whose ``model`` entry is an MlflowArtifactDataset wrapping a pickle."""
    model_dataset = MlflowArtifactDataset(
        dataset=dict(
            type=PickleDataset,
            filepath=(tmp_path / "model.csv").as_posix(),
        )
    )
    return DataCatalog(
        {
            "raw_data": MemoryDataset(pd.DataFrame(data=[1], columns=["a"])),
            "data": MemoryDataset(),
            "model": model_dataset,
        }
    )


@pytest.fixture
def dummy_run_params(tmp_path):
    """Minimal ``run_params`` payload mimicking what Kedro passes to hooks."""
    return {
        "project_path": tmp_path.as_posix(),
        "env": "local",
        "kedro_version": "0.16.0",
        "tags": [],
        "from_nodes": [],
        "to_nodes": [],
        "node_names": [],
        "from_inputs": [],
        "load_versions": [],
        "pipeline_name": "my_cool_pipeline",
        "extra_params": [],
    }


def test_mlflow_hook_automatically_update_artifact_run_id(
    kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
):
    # Since mlflow>=2.18, the fluent API creates a new run for each thread,
    # hence for ThreadRunner the catalog datasets must be pinned to the run id.

    bootstrap_project(kedro_project)
    with KedroSession.create(project_path=kedro_project) as session:
        context = session.load_context()  # triggers conf setup

        mlflow_hook = MlflowHook()
        mlflow_hook.after_context_created(context)  # setup mlflow config

        mlflow_hook.after_catalog_created(
            catalog=dummy_catalog,
            # `after_catalog_created` is not using any of below arguments,
            # so we are setting them to empty values.
            conf_catalog={},
            conf_creds={},
            feed_dict={},
            save_version="",
            load_versions="",
        )

        mlflow_hook.before_pipeline_run(
            run_params=dummy_run_params, pipeline=dummy_pipeline, catalog=dummy_catalog
        )

        run_id = mlflow.active_run().info.run_id
        # `before_pipeline_run` should have propagated the active run id
        # to the MlflowArtifactDataset entries of the catalog.
        assert dummy_catalog._datasets["model"].run_id == run_id


def test_mlflow_hook_log_artifacts_within_same_run_with_thread_runner(
    kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
):
    # This test is very specific to a new design introduced in mlflow 2.18
    # to make it thread safe,
    # see https://github.com/Galileo-Galilei/kedro-mlflow/issues/613
    bootstrap_project(kedro_project)

    with KedroSession.create(project_path=kedro_project) as session:
        context = session.load_context()  # setup mlflow

        mlflow_hook = MlflowHook()
        runner = ThreadRunner()  # this is what we want to test

        mlflow_hook.after_context_created(context)
        mlflow_hook.after_catalog_created(
            catalog=dummy_catalog,
            # `after_catalog_created` is not using any of the arguments below,
            # so we are setting them to empty values.
            conf_catalog={},
            conf_creds={},
            feed_dict={},
            save_version="",
            load_versions="",
        )
        mlflow_hook.before_pipeline_run(
            run_params=dummy_run_params,
            pipeline=dummy_pipeline,
            catalog=dummy_catalog,
        )

        # we get the run id BEFORE running the pipeline because it may be
        # modified in a different thread
        run_id_before_run = mlflow.active_run().info.run_id

        runner.run(dummy_pipeline, dummy_catalog, session._hook_manager)

        run_id_after_run = mlflow.active_run().info.run_id

        # CHECK 1: check that we are still on the run opened before the pipeline,
        # not on a second run created by the threaded execution
        assert run_id_before_run == run_id_after_run

        mlflow_hook.after_pipeline_run(
            run_params=dummy_run_params,
            pipeline=dummy_pipeline,
            catalog=dummy_catalog,
        )

        mlflow_client = context.mlflow.server._mlflow_client

        # CHECK 2: check that the artifact is associated to the initial run

        artifacts_list = mlflow_client.list_artifacts(run_id_before_run)
        assert len(artifacts_list) == 1

0 comments on commit b41dddb

Please sign in to comment.