From b41dddbb877b123a15cdcbc1b806ef0f7aeddd59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yolan=20Honor=C3=A9-Roug=C3=A9?=
Date: Fri, 13 Dec 2024 17:32:15 +0100
Subject: [PATCH] :ambulance: Ensure MlflowArtifactDataset logs in the same
 run when using mlflow>=2.18 and ThreadRunner

---
 CHANGELOG.md                                  |   4 +
 kedro_mlflow/framework/hooks/mlflow_hook.py   |  10 ++
 .../io/artifacts/mlflow_artifact_dataset.py  |   6 +
 .../add_run_id_to_artifact_datasets.py       |   9 +
 tests/conftest.py                             |   3 +-
 .../framework/hooks/test_hook_log_artifact.py | 159 ++++++++++++++++++
 6 files changed, 189 insertions(+), 2 deletions(-)
 create mode 100644 kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py
 create mode 100644 tests/framework/hooks/test_hook_log_artifact.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8144230..a66b9a6f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@

 ## [Unreleased]

+### Fixed
+
+- :bug: :ambulance: Ensure `MlflowArtifactDataset` logs artifacts in the same run as parameters when using ``mlflow>=2.18`` in combination with ``ThreadRunner`` ([#613](https://github.com/Galileo-Galilei/kedro-mlflow/issues/613))
+
 ## [0.13.3] - 2024-10-29

 ### Added

diff --git a/kedro_mlflow/framework/hooks/mlflow_hook.py b/kedro_mlflow/framework/hooks/mlflow_hook.py
index 7aed8ccb..9db324b8 100644
--- a/kedro_mlflow/framework/hooks/mlflow_hook.py
+++ b/kedro_mlflow/framework/hooks/mlflow_hook.py
@@ -27,6 +27,9 @@
     _flatten_dict,
     _generate_kedro_command,
 )
+from kedro_mlflow.io.catalog.add_run_id_to_artifact_datasets import (
+    add_run_id_to_artifact_datasets,
+)
 from kedro_mlflow.io.catalog.switch_catalog_logging import switch_catalog_logging
 from kedro_mlflow.io.metrics import (
     MlflowMetricDataset,
@@ -270,6 +273,13 @@ def before_pipeline_run(
                     pipeline_name=run_params["pipeline_name"],
                 ),
             )
+
+            # Associate the run_id started at the beginning of the pipeline with all
+            # the artifact datasets. This is necessary because, to make mlflow thread
+            # safe, each call to the "active run" now creates a new run when started in a new thread.
+            # See https://github.com/Galileo-Galilei/kedro-mlflow/issues/613 and https://github.com/Galileo-Galilei/kedro-mlflow/pull/615
+            add_run_id_to_artifact_datasets(catalog, mlflow.active_run().info.run_id)
+
         else:
             self._logger.info(
                 "kedro-mlflow logging is deactivated for this pipeline in the configuration. This includes DataSets and parameters."
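
For context on the thread-safety change this hunk works around: since mlflow 2.18 the fluent
API tracks the active run per thread, so code running in a worker thread no longer sees the run
started in the main thread. A minimal standalone sketch of the symptom (illustrative only, not
part of this patch; assumes mlflow>=2.18 and a default local tracking store):

    import threading

    import mlflow

    with mlflow.start_run() as run:
        print("main thread run:", run.info.run_id)

        def worker():
            # With thread-local run tracking, the worker sees no active run;
            # any code calling mlflow.start_run() here would open a second run.
            print("worker sees active run:", mlflow.active_run())

        t = threading.Thread(target=worker)
        t.start()
        t.join()

Pinning an explicit run_id on each artifact dataset, as the hook now does, sidesteps this.
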
diff --git a/kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py b/kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
index 9e1d42f3..67e1fb97 100644
--- a/kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
+++ b/kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
@@ -171,3 +171,9 @@ def _describe(self) -> Dict[str, Any]:  # pragma: no cover
         and consequently does not implements abtracts methods
         """
         pass
+
+
+def _is_instance_mlflow_artifact_dataset(dataset_instance):
+    parent_classname = dataset_instance.__class__.__bases__[0].__name__
+    instance_classname = f"Mlflow{parent_classname}"
+    return type(dataset_instance).__name__ == instance_classname
diff --git a/kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py b/kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py
new file mode 100644
index 00000000..05f27351
--- /dev/null
+++ b/kedro_mlflow/io/catalog/add_run_id_to_artifact_datasets.py
@@ -0,0 +1,9 @@
+from kedro_mlflow.io.artifacts.mlflow_artifact_dataset import (
+    _is_instance_mlflow_artifact_dataset,
+)
+
+
+def add_run_id_to_artifact_datasets(catalog, run_id: str):
+    for name, dataset in catalog._datasets.items():
+        if _is_instance_mlflow_artifact_dataset(dataset):
+            catalog._datasets[name].run_id = run_id
diff --git a/tests/conftest.py b/tests/conftest.py
index 242a0e1a..10a9d3fa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -29,8 +29,7 @@ def mlflow_client(tracking_uri):


 @pytest.fixture(autouse=True)
 def cleanup_mlflow_after_runs():
-    # A test function will be run at this point
-    yield
+    yield  # A test function will be run at this point
     while mlflow.active_run():
         mlflow.end_run()
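
In isolation, the new helper simply walks the catalog and pins the given run on every
`MlflowArtifactDataset` it finds. A hedged usage sketch (the catalog entry and filepath are
illustrative, not part of this patch):

    import mlflow
    from kedro.io import DataCatalog
    from kedro_datasets.pickle import PickleDataset

    from kedro_mlflow.io.artifacts import MlflowArtifactDataset
    from kedro_mlflow.io.catalog.add_run_id_to_artifact_datasets import (
        add_run_id_to_artifact_datasets,
    )

    catalog = DataCatalog(
        {
            "model": MlflowArtifactDataset(
                dataset=dict(type=PickleDataset, filepath="data/model.pkl")
            )
        }
    )

    with mlflow.start_run() as run:
        # Pin the run started here on every artifact dataset of the catalog, so
        # saves happening in worker threads land in this run instead of a new one.
        add_run_id_to_artifact_datasets(catalog, run.info.run_id)
        assert catalog._datasets["model"].run_id == run.info.run_id
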
diff --git a/tests/framework/hooks/test_hook_log_artifact.py b/tests/framework/hooks/test_hook_log_artifact.py
new file mode 100644
index 00000000..9cd32052
--- /dev/null
+++ b/tests/framework/hooks/test_hook_log_artifact.py
@@ -0,0 +1,159 @@
+import mlflow
+import pandas as pd
+import pytest
+from kedro.framework.session import KedroSession
+from kedro.framework.startup import bootstrap_project
+from kedro.io import DataCatalog, MemoryDataset
+from kedro.pipeline import Pipeline, node
+from kedro.runner import ThreadRunner
+from kedro_datasets.pickle import PickleDataset
+
+from kedro_mlflow.framework.hooks.mlflow_hook import MlflowHook
+from kedro_mlflow.io.artifacts import MlflowArtifactDataset
+
+
+@pytest.fixture
+def dummy_pipeline():
+    def preprocess_fun(data):
+        return data
+
+    def train_fun(data):
+        return 2
+
+    dummy_pipeline = Pipeline(
+        [
+            node(
+                func=preprocess_fun,
+                inputs="raw_data",
+                outputs="data",
+            ),
+            node(
+                func=train_fun,
+                inputs=["data"],
+                outputs="model",
+            ),
+        ]
+    )
+    return dummy_pipeline
+
+
+@pytest.fixture
+def dummy_catalog(tmp_path):
+    dummy_catalog = DataCatalog(
+        {
+            "raw_data": MemoryDataset(pd.DataFrame(data=[1], columns=["a"])),
+            "data": MemoryDataset(),
+            "model": MlflowArtifactDataset(
+                dataset=dict(
+                    type=PickleDataset, filepath=(tmp_path / "model.csv").as_posix()
+                )
+            ),
+        }
+    )
+    return dummy_catalog
+
+
+@pytest.fixture
+def dummy_run_params(tmp_path):
+    dummy_run_params = {
+        "project_path": tmp_path.as_posix(),
+        "env": "local",
+        "kedro_version": "0.16.0",
+        "tags": [],
+        "from_nodes": [],
+        "to_nodes": [],
+        "node_names": [],
+        "from_inputs": [],
+        "load_versions": [],
+        "pipeline_name": "my_cool_pipeline",
+        "extra_params": [],
+    }
+    return dummy_run_params
+
+
+def test_mlflow_hook_automatically_update_artifact_run_id(
+    kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
+):
+    # since mlflow>=2.18, the fluent API creates a new run for each thread,
+    # hence with the ThreadRunner we need to pin the run id on the catalog's artifact datasets
+
+    bootstrap_project(kedro_project)
+    with KedroSession.create(project_path=kedro_project) as session:
+        context = session.load_context()  # triggers conf setup
+
+        mlflow_hook = MlflowHook()
+        mlflow_hook.after_context_created(context)  # setup mlflow config
+
+        mlflow_hook.after_catalog_created(
+            catalog=dummy_catalog,
+            # `after_catalog_created` is not using any of the arguments below,
+            # so we are setting them to empty values.
+            conf_catalog={},
+            conf_creds={},
+            feed_dict={},
+            save_version="",
+            load_versions="",
+        )
+
+        mlflow_hook.before_pipeline_run(
+            run_params=dummy_run_params, pipeline=dummy_pipeline, catalog=dummy_catalog
+        )
+
+        run_id = mlflow.active_run().info.run_id
+        # check that the hook attached the active run id to the artifact dataset
+        assert dummy_catalog._datasets["model"].run_id == run_id
+
+
+def test_mlflow_hook_log_artifacts_within_same_run_with_thread_runner(
+    kedro_project, dummy_run_params, dummy_pipeline, dummy_catalog
+):
+    # this test is very specific to the new design introduced in mlflow 2.18 to make it thread safe
+    # see https://github.com/Galileo-Galilei/kedro-mlflow/issues/613
+    bootstrap_project(kedro_project)
+
+    with KedroSession.create(project_path=kedro_project) as session:
+        context = session.load_context()  # setup mlflow
+
+        mlflow_hook = MlflowHook()
+        runner = ThreadRunner()  # this is what we want to test
+
+        mlflow_hook.after_context_created(context)
+        mlflow_hook.after_catalog_created(
+            catalog=dummy_catalog,
+            # `after_catalog_created` is not using any of the arguments below,
+            # so we are setting them to empty values.
+            conf_catalog={},
+            conf_creds={},
+            feed_dict={},
+            save_version="",
+            load_versions="",
+        )
+        mlflow_hook.before_pipeline_run(
+            run_params=dummy_run_params,
+            pipeline=dummy_pipeline,
+            catalog=dummy_catalog,
+        )
+
+        # we get the run id BEFORE running the pipeline because it may be modified in a different thread
+        run_id_before_run = mlflow.active_run().info.run_id
+
+        runner.run(dummy_pipeline, dummy_catalog, session._hook_manager)
+
+        run_id_after_run = mlflow.active_run().info.run_id
+
+        # CHECK 1: ensure we are still on the original run, not on a second run created in another thread
+        assert run_id_before_run == run_id_after_run
+
+        mlflow_hook.after_pipeline_run(
+            run_params=dummy_run_params,
+            pipeline=dummy_pipeline,
+            catalog=dummy_catalog,
+        )
+
+        mlflow_client = context.mlflow.server._mlflow_client
+
+        # CHECK 2: check that the artifact is associated with the initial run
+        artifacts_list = mlflow_client.list_artifacts(run_id_before_run)
+        assert len(artifacts_list) == 1
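
End to end, the fix means a pipeline executed with `kedro run --runner=ThreadRunner` should now
log parameters and artifacts under a single run. A hedged sketch of how to verify this outside
the test suite (assumes a project configured with kedro-mlflow, an `MlflowArtifactDataset` entry
in the catalog, and that the runs land in the default experiment "0"):

    from mlflow.tracking import MlflowClient

    client = MlflowClient()
    run = client.search_runs(experiment_ids=["0"], max_results=1)[0]
    # Parameters and artifacts should now share this single run id.
    print(run.info.run_id, [a.path for a in client.list_artifacts(run.info.run_id)])
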