✨ Enable passing kwargs to create_experiment from configuration, including artifact_location (#557)
Galileo-Galilei committed Feb 18, 2025
1 parent 0f15f2e commit 4311e34
Showing 7 changed files with 145 additions and 34 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -4,7 +4,8 @@

### Added

- :sparkles: Add the ``tracking.disable_tracking.disable_autologging`` configuration option in ``mlflow.yml ``to disable autologging by default. This simplify the workflow for Databricks users who have autologging activated by default, which conflicts with the plugin([[#610](https://github.com/Galileo-Galilei/kedro-mlflow/issues/610)]).
- :sparkles: Add the ``tracking.disable_tracking.disable_autologging`` configuration option in ``mlflow.yml`` to disable autologging by default. This simplifies the workflow for Databricks users who have autologging activated by default, which conflicts with ``kedro-mlflow`` ([[#610](https://github.com/Galileo-Galilei/kedro-mlflow/issues/610)]).
- :sparkles: Add ``tracking.experiment.create_experiment_kwargs.artifact_location`` and ``tracking.experiment.create_experiment_kwargs.tags`` configuration options in ``mlflow.yml`` to enable advanced configuration of the mlflow experiment created at runtime by ``kedro-mlflow`` ([[#557](https://github.com/Galileo-Galilei/kedro-mlflow/issues/557)]).

## [0.14.3] - 2025-02-17

@@ -53,16 +53,21 @@ tracking:
# in a new mlflow run

disable_tracking:
disable_autologging: True # If True, we force autologging to be disabled. This is useful on databricks, where autologging is activated by default and conflicts with the plugin. If False, we keep the default behaviour, where autologging is disabled anyway.
pipelines: []

experiment:
name: {{ python_package }}
create_experiment_kwargs: # will be used only if the experiment does not exist yet and is created.
artifact_location: null # enables specifying an artifact location for the experiment different from the global one of the mlflow server
tags: null # a dict of tags for the experiment
restore_if_deleted: True # if the experiment `name` was previously deleted, should we restore it?

run:
id: null # if `id` is None, a new run will be created
name: null # if `name` is None, pipeline name will be used for the run name. You can use "${km.random_name:}" to generate a random name (mlflow's default)
nested: True # if `nested` is False, you won't be able to launch sub-runs inside your nodes

params:
dict_params:
flatten: False # if True, parameters which are dictionaries will be split into multiple parameters when logged in mlflow, one for each key.
@@ -175,12 +180,13 @@ Mlflow enables the user to create "experiments" to organize their work. The differe…
```yaml
tracking:
experiment:
name: <your-experiment-name> # by default, the name of your python package in your kedro project
name: {{ python_package }}
create_experiment_kwargs: # will be used only if the experiment does not exist yet and is created.
artifact_location: null # enables specifying an artifact location for the experiment different from the global one of the mlflow server
tags: null # a dict of tags for the experiment
restore_if_deleted: True # if the experiment `name` was previously deleted, should we restore it?
```

Note that by default, mlflow raises an error if you try to start a run without having created the experiment first. `kedro-mlflow` has a `create` key (`True` by default) which forces the creation of the experiment if it does not exist. Set it to `False` to match mlflow's default behaviour.
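
For illustration, here is a hypothetical `mlflow.yml` excerpt (the bucket name and tags below are made up) that stores the experiment's artifacts on S3 instead of the tracking server's default location:

```yaml
tracking:
  experiment:
    name: my_project
    create_experiment_kwargs: # only applied if the experiment does not exist yet
      artifact_location: s3://my-team-bucket/mlflow-artifacts # hypothetical bucket
      tags:
        team: data-science
        stage: experimentation
    restore_if_deleted: True
```

Keep in mind that these kwargs only take effect when the experiment is created: an experiment which already exists keeps its original artifact location and tags.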

### Configure the run

When you launch a new `kedro` run, `kedro-mlflow` instantiates an underlying `mlflow` run through the hooks. By default, we assume the user wants to launch each kedro run in a separate mlflow run to keep a one-to-one relationship between kedro runs and mlflow runs. However, one may need to *continue* an existing mlflow run (for instance, because you resume the kedro run from a later starting point of your pipeline).
@@ -191,8 +197,8 @@ The `mlflow.yml` accepts the following keys:
tracking:
run:
id: null # if `id` is None, a new run will be created
name: null # if `name` is None, pipeline name will be used for the run name
nested: True # # if `nested` is False, you won't be able to launch sub-runs inside your nodes
name: null # if `name` is None, pipeline name will be used for the run name. You can use "${km.random_name:}" to generate a random name (mlflow's default)
nested: True # if `nested` is False, you won't be able to launch sub-runs inside your nodes
```
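
For example, a hypothetical configuration (the run id below is made up) that *continues* an existing mlflow run instead of creating a new one would look like this:

```yaml
tracking:
  run:
    id: 4f1fd8fbb9394aa4a523b61b9ea267ba # hypothetical id of the run to resume
    name: null # not used when attaching to an existing run
    nested: True
```

With a non-null `id`, the hooks attach to that run rather than creating a new one, which is what makes resuming a pipeline from an intermediate node possible.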
58 changes: 42 additions & 16 deletions kedro_mlflow/config/kedro_mlflow_config.py
@@ -1,7 +1,7 @@
import os
from logging import getLogger
from pathlib import Path, PurePath
from typing import Dict, List, Optional
from typing import Optional
from urllib.parse import urlparse

import mlflow
@@ -23,7 +23,7 @@ class RequestHeaderProviderOptions(BaseModel):
# mutable default is ok for pydantic : https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value
type: Optional[str] = None
pass_context: bool = False
init_kwargs: Dict[str, str] = {}
init_kwargs: dict[str, str] = {}

class Config:
extra = "forbid"
@@ -47,15 +47,24 @@ class Config:

class DisableTrackingOptions(BaseModel):
# mutable default is ok for pydantic : https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value
pipelines: List[str] = []
pipelines: list[str] = []
disable_autologging: bool = True

class Config:
extra = "forbid"


class CreateExperimentOptions(BaseModel):
artifact_location: Optional[str] = None
tags: Optional[dict] = None

class Config:
extra = "forbid"
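
As a quick illustration (not part of the commit), the new model validates the `create_experiment_kwargs` section of `mlflow.yml`; the values below are made up, and unknown keys are rejected because of `extra = "forbid"`:

```python
from pydantic import ValidationError

from kedro_mlflow.config.kedro_mlflow_config import CreateExperimentOptions

# both fields are optional and default to None
opts = CreateExperimentOptions(
    artifact_location="s3://my-bucket/mlflow",  # hypothetical location
    tags={"team": "data-science"},
)

# a typo in a key name fails fast instead of being silently ignored
try:
    CreateExperimentOptions(artifact_locatoin="oops")
except ValidationError as err:
    print(err)
```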


class ExperimentOptions(BaseModel):
name: str = "Default"
create_experiment_kwargs: CreateExperimentOptions = CreateExperimentOptions()
restore_if_deleted: StrictBool = True
_experiment: Experiment = PrivateAttr()
# do not create _experiment immediately to avoid creating
@@ -75,7 +84,7 @@ class Config:
extra = "forbid"


class DictParamsOptions(BaseModel):
flatten: StrictBool = False
recursive: StrictBool = True
sep: str = "."
@@ -85,7 +94,7 @@ class Config:


class MlflowParamsOptions(BaseModel):
dict_params: DictParamsOptions = DictParamsOptions()
long_params_strategy: Literal["fail", "truncate", "tag"] = "fail"

class Config:
@@ -157,6 +166,16 @@ def setup(self, context):
mlflow.set_tracking_uri(self.server.mlflow_tracking_uri)
mlflow.set_registry_uri(self.server.mlflow_registry_uri)

# before we set the experiment, ensure the artifact location is a valid uri
if (
self.tracking.experiment.create_experiment_kwargs.artifact_location
is not None
):
self.tracking.experiment.create_experiment_kwargs.artifact_location = _validate_uri(
project_path=context.project_path,
uri=self.tracking.experiment.create_experiment_kwargs.artifact_location,
)

self._set_experiment()

if self.tracking.disable_tracking.disable_autologging is True:
@@ -206,23 +225,30 @@ def _set_experiment(self):
Returns:
mlflow.entities.Experiment -- [description]
"""
# we retrieve the experiment manually to check if it exsits
# we retrieve the experiment manually to check if it exists
mlflow_experiment = self.server._mlflow_client.get_experiment_by_name(
name=self.tracking.experiment.name
)
# Deal with two side cases when retrieving the experiment
if mlflow_experiment is not None:
if (
self.tracking.experiment.restore_if_deleted
and mlflow_experiment.lifecycle_stage == "deleted"
):
# the experiment was created, then deleted : we have to restore it manually before setting it as the active one
self.server._mlflow_client.restore_experiment(
mlflow_experiment.experiment_id
)
if mlflow_experiment is None:
# we create the experiment if it does not exist
self.server._mlflow_client.create_experiment(
name=self.tracking.experiment.name,
artifact_location=self.tracking.experiment.create_experiment_kwargs.artifact_location,
tags=self.tracking.experiment.create_experiment_kwargs.tags,
)
elif (
self.tracking.experiment.restore_if_deleted
and mlflow_experiment.lifecycle_stage == "deleted"
):
# the experiment was created, then deleted : we have to restore it manually before setting it as the active one
self.server._mlflow_client.restore_experiment(
mlflow_experiment.experiment_id
)

# this creates the experiment if it does not exist
# and creates a global variable with the experiment
# but it has already been done by the previous 'else' to allow advanced control at creation
# It creates a global variable with the experiment
# but returns nothing
mlflow.set_experiment(experiment_name=self.tracking.experiment.name)
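
To make the new flow concrete, here is a minimal standalone sketch of the same get-or-create-or-restore logic using mlflow's public client API (the tracking uri, experiment name, artifact location and tags below are made up):

```python
import mlflow
from mlflow.tracking import MlflowClient

mlflow.set_tracking_uri("mlruns")  # hypothetical local store
client = MlflowClient()
name = "exp_custom_artifact_location"

experiment = client.get_experiment_by_name(name)
if experiment is None:
    # create_experiment_kwargs are only applied here, at creation time
    client.create_experiment(
        name=name,
        artifact_location="mlruns_custom",  # hypothetical location
        tags={"my_tag": "my_value"},
    )
elif experiment.lifecycle_stage == "deleted":
    # restore a soft-deleted experiment before activating it
    client.restore_experiment(experiment.experiment_id)

# set the experiment as the active one (it already exists at this point)
mlflow.set_experiment(experiment_name=name)
```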

3 changes: 3 additions & 0 deletions kedro_mlflow/template/project/mlflow.yml
@@ -36,6 +36,9 @@ tracking:

experiment:
name: {{ python_package }}
create_experiment_kwargs: # will be used only if the experiment does not exist yet and is created.
artifact_location: null # enables specifying an artifact location for the experiment different from the global one of the mlflow server
tags: null # a dict of tags for the experiment
restore_if_deleted: True # if the experiment `name` was previously deleted, should we restore it?

run:
54 changes: 45 additions & 9 deletions tests/config/test_get_mlflow_config.py
@@ -39,7 +39,11 @@ def test_mlflow_config_default(kedro_project):
pipelines=["my_disabled_pipeline"],
disable_autologging=True,
),
experiment=dict(name="fake_package", restore_if_deleted=True),
experiment=dict(
name="fake_package",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id="123456789", name="my_run", nested=True),
params=dict(
dict_params=dict(
@@ -81,7 +85,11 @@ def test_mlflow_config_in_uninitialized_project(kedro_project, package_name):
pipelines=[],
disable_autologging=True,
),
experiment=dict(name="fake_project", restore_if_deleted=True),
experiment=dict(
name="fake_project",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(flatten=False, recursive=True, sep="."),
@@ -111,7 +119,11 @@ def test_mlflow_config_with_no_experiment_name(kedro_project):
pipelines=[],
disable_autologging=True,
),
experiment=dict(name="fake_project", restore_if_deleted=True),
experiment=dict(
name="fake_project",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(flatten=False, recursive=True, sep="."),
@@ -225,7 +237,11 @@ def test_mlflow_config_correctly_set(kedro_project, project_settings):
),
tracking=dict(
disable_tracking=dict(pipelines=[], disable_autologging=True),
experiment=dict(name="fake_project", restore_if_deleted=True),
experiment=dict(
name="fake_project",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(flatten=False, recursive=True, sep="."),
Expand All @@ -251,7 +267,11 @@ def test_mlflow_config_interpolated_with_globals_resolver(monkeypatch, fake_proj
pipelines=["my_disabled_pipeline"],
disable_autologging=True,
),
experiment=dict(name="fake_package", restore_if_deleted=True),
experiment=dict(
name="fake_package",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id="123456789", name="my_run", nested=True),
params=dict(
dict_params=dict(
@@ -318,7 +338,11 @@ def request_headers(self):
),
tracking=dict(
disable_tracking=dict(pipelines=[]),
experiment=dict(name="Default", restore_if_deleted=True),
experiment=dict(
name="Default",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(
@@ -377,7 +401,11 @@ def request_headers(self):
),
tracking=dict(
disable_tracking=dict(pipelines=[]),
experiment=dict(name="Default", restore_if_deleted=True),
experiment=dict(
name="Default",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(
@@ -440,7 +468,11 @@ def request_headers(self):
),
tracking=dict(
disable_tracking=dict(pipelines=[]),
experiment=dict(name="Default", restore_if_deleted=True),
experiment=dict(
name="Default",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(
@@ -495,7 +527,11 @@ def request_headers(self):
),
tracking=dict(
disable_tracking=dict(pipelines=[]),
experiment=dict(name="Default", restore_if_deleted=True),
experiment=dict(
name="Default",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(
39 changes: 37 additions & 2 deletions tests/config/test_kedro_mlflow_config.py
@@ -20,7 +20,11 @@ def test_kedro_mlflow_config_init():
),
tracking=dict(
disable_tracking=dict(pipelines=[], disable_autologging=True),
experiment=dict(name="Default", restore_if_deleted=True),
experiment=dict(
name="Default",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
run=dict(id=None, name=None, nested=True),
params=dict(
dict_params=dict(
@@ -35,7 +39,7 @@
)


def test_kedro_mlflow_config_new_experiment_does_not_exists(
def test_kedro_mlflow_config_new_experiment_is_created_if_does_not_exists(
kedro_project_with_mlflow_conf,
):
config = KedroMlflowConfig(
@@ -53,6 +57,37 @@ def test_kedro_mlflow_config_new_experiment_does_not_exists(
]


def test_kedro_mlflow_config_new_experiment_is_created_with_kwargs(
kedro_project_with_mlflow_conf,
):
config = KedroMlflowConfig(
server=dict(mlflow_tracking_uri="mlruns"),
tracking=dict(
experiment=dict(
name="exp_custom_artifact_location",
create_experiment_kwargs=dict(
artifact_location="mlruns_custom", tags={"my_tag": "my_value"}
),
)
),
)

bootstrap_project(kedro_project_with_mlflow_conf)
with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
context = session.load_context() # setup config
config.setup(context)

experiment = config.server._mlflow_client.get_experiment_by_name(
"exp_custom_artifact_location"
)
assert experiment is not None
assert (
experiment.artifact_location
== (kedro_project_with_mlflow_conf / "mlruns_custom").as_uri()
)
assert experiment.tags == {"my_tag": "my_value"}
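
The assertion above works because `setup` validates the configured location with the `_validate_uri` helper (see the `setup` diff above), which resolves a relative `artifact_location` against the project path and normalizes it to a `file://` URI before handing it to mlflow; pathlib makes this easy to see (the path below is made up):

```python
from pathlib import Path

# a relative location like "mlruns_custom" ends up as an absolute file:// URI
print((Path("/tmp/fake_project") / "mlruns_custom").as_uri())
# file:///tmp/fake_project/mlruns_custom
```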


def test_kedro_mlflow_config_with_use_env_tracking_uri(
kedro_project_with_mlflow_conf,
):
6 changes: 5 additions & 1 deletion tests/template/project/test_mlflow_yml.py
@@ -34,7 +34,11 @@ def test_mlflow_yml_rendering(template_mlflowyml):
project_path="fake/path",
tracking=dict(
disable_tracking=dict(pipelines=[], disable_autologging=True),
experiment=dict(name="fake_project", restore_if_deleted=True),
experiment=dict(
name="fake_project",
create_experiment_kwargs=dict(artifact_location=None, tags=None),
restore_if_deleted=True,
),
params=dict(
dict_params=dict(flatten=False, recursive=True, sep="."),
long_params_strategy="fail",
