From d9f0c5140c681bd9d5214e3c8d552186016aac96 Mon Sep 17 00:00:00 2001
From: Will Dean <57733339+wd60622@users.noreply.github.com>
Date: Thu, 23 Jan 2025 17:01:19 +0100
Subject: [PATCH] MLflow callback logging (#1425)

* add logic

* add test against logged metrics

* add to the docstring

* add to coverage
---
 pymc_marketing/mlflow.py | 125 +++++++++++++++++++++++++++++++++++++++
 tests/test_mlflow.py     |  35 +++++++++++
 2 files changed, 160 insertions(+)

diff --git a/pymc_marketing/mlflow.py b/pymc_marketing/mlflow.py
index 6a01418f..e605e610 100644
--- a/pymc_marketing/mlflow.py
+++ b/pymc_marketing/mlflow.py
@@ -180,6 +180,131 @@
     warnings.warn(warning_msg, FutureWarning, stacklevel=1)
 
 
+def _exclude_tuning(func):
+    def callback(trace, draw):
+        if draw.tuning:
+            return
+
+        return func(trace, draw)
+
+    return callback
+
+
+def _take_every(n: int):
+    def decorator(func):
+        def callback(trace, draw):
+            if draw.draw_idx % n != 0:
+                return
+
+            return func(trace, draw)
+
+        return callback
+
+    return decorator
+
+
+def create_log_callback(
+    stats: list[str] | None = None,
+    parameters: list[str] | None = None,
+    exclude_tuning: bool = True,
+    take_every: int = 100,
+):
+    """Create callback function to log sample stats and parameter values to MLflow during sampling.
+
+    This callback only works for the "pymc" sampler.
+
+    Parameters
+    ----------
+    stats : list of str, optional
+        List of sample statistics to log from the Draw
+    parameters : list of str, optional
+        List of parameters to log from the Draw
+    exclude_tuning : bool, optional
+        Whether to exclude tuning steps from logging. Defaults to True.
+
+    Returns
+    -------
+    callback : Callable
+        The callback function to log sample stats and parameter values to MLflow during sampling
+
+    Examples
+    --------
+    Create example model:
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            mu = pm.Normal("mu")
+            sigma = pm.HalfNormal("sigma")
+            obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
+
+    Log divergences and logp every 100th draw:
+
+    .. code-block:: python
+
+        import mlflow
+
+        from pymc_marketing.mlflow import create_log_callback
+
+        callback = create_log_callback(
+            stats=["diverging", "model_logp"],
+            take_every=100,
+        )
+
+        mlflow.set_experiment("Live Tracking Stats")
+
+        with mlflow.start_run():
+            idata = pm.sample(model=model, callback=callback)
+
+    Log the parameters `mu` and `sigma_log__` every 100th draw:
+
+    .. code-block:: python
+
+        import mlflow
+
+        from pymc_marketing.mlflow import create_log_callback
+
+        callback = create_log_callback(
+            parameters=["mu", "sigma_log__"],
+            take_every=100,
+        )
+
+        mlflow.set_experiment("Live Tracking Parameters")
+
+        with mlflow.start_run():
+            idata = pm.sample(model=model, callback=callback)
+
+    """
+    if not stats and not parameters:
+        raise ValueError("At least one of `stats` or `parameters` must be provided.")
+
+    def callback(_, draw):
+        prefix = f"chain_{draw.chain}"
+        for stat in stats or []:
+            mlflow.log_metric(
+                key=f"{prefix}/{stat}",
+                value=draw.stats[0][stat],
+                step=draw.draw_idx,
+            )
+
+        for parameter in parameters or []:
+            mlflow.log_metric(
+                key=f"{prefix}/{parameter}",
+                value=draw.point[parameter],
+                step=draw.draw_idx,
+            )
+
+    if exclude_tuning:
+        callback = _exclude_tuning(callback)
+
+    if take_every:
+        callback = _take_every(n=take_every)(callback)
+
+    return callback
+
+
 def _log_and_remove_artifact(path: str | Path) -> None:
     """Log an artifact to MLflow and then remove the local file.
 
diff --git a/tests/test_mlflow.py b/tests/test_mlflow.py
index 1cd63ff6..065f26d6 100644
--- a/tests/test_mlflow.py
+++ b/tests/test_mlflow.py
@@ -28,6 +28,7 @@
 from pymc_marketing.clv import BetaGeoModel
 from pymc_marketing.mlflow import (
     autolog,
+    create_log_callback,
     log_likelihood_type,
     log_mmm_evaluation_metrics,
     log_model_graph,
@@ -674,3 +675,37 @@ def test_log_mmm_evaluation_metrics() -> None:
 
     assert set(run_data.metrics.keys()) == expected_metrics
     assert all(isinstance(value, float) for value in run_data.metrics.values())
+
+
+def test_callback_raises() -> None:
+    match = "At least one of"
+    with pytest.raises(ValueError, match=match):
+        create_log_callback()
+
+
+def test_logging_callback(model_with_likelihood) -> None:
+    mlflow.set_experiment("pymc-marketing-test-suite-logging-callback")
+
+    callback = create_log_callback(
+        stats=["energy"],
+        parameters=["mu"],
+        take_every=10,
+    )
+    with mlflow.start_run() as run:
+        pm.sample(
+            model=model_with_likelihood,
+            draws=100,
+            tune=1,
+            chains=2,
+            callback=callback,
+        )
+
+    assert mlflow.active_run() is None
+
+    run_id = run.info.run_id
+    client = MlflowClient()
+
+    for chain in [0, 1]:
+        for value in ["energy", "mu"]:
+            history = client.get_metric_history(run_id, f"chain_{chain}/{value}")
+            assert len(history) == 10
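
Usage sketch (follow-up, not part of the applied diff above): a minimal example of how the new `create_log_callback` could be used outside the test suite and how the logged series could be read back afterwards. The toy model, the experiment name "Live Tracking Example", and the `take_every=50` cadence are illustrative assumptions; the MLflow calls (`mlflow.set_experiment`, `mlflow.start_run`, `MlflowClient.get_metric_history`) and the `chain_<n>/<name>` metric keys mirror those exercised in `tests/test_mlflow.py`.

.. code-block:: python

    import mlflow
    from mlflow import MlflowClient

    import pymc as pm

    from pymc_marketing.mlflow import create_log_callback

    # Toy model; any model sampled with the default "pymc" sampler works.
    with pm.Model() as model:
        mu = pm.Normal("mu")
        sigma = pm.HalfNormal("sigma")
        pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])

    # Log the energy sample stat plus the parameters `mu` and `sigma_log__`
    # every 50th draw, skipping tuning draws (exclude_tuning defaults to True).
    callback = create_log_callback(
        stats=["energy"],
        parameters=["mu", "sigma_log__"],
        take_every=50,
    )

    mlflow.set_experiment("Live Tracking Example")

    with mlflow.start_run() as run:
        idata = pm.sample(model=model, callback=callback)

    # Metrics are keyed per chain, e.g. "chain_0/energy" or "chain_1/mu".
    client = MlflowClient()
    history = client.get_metric_history(run.info.run_id, "chain_0/energy")
    print([metric.value for metric in history])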