MLflow callback logging (#1425)
* add logic

* add test against logged metrics

* add to the docstring

* add to coverage
wd60622 authored Jan 23, 2025
1 parent 2ef7b51 commit d9f0c51
Showing 2 changed files with 160 additions and 0 deletions.
125 changes: 125 additions & 0 deletions pymc_marketing/mlflow.py
@@ -180,6 +180,131 @@
warnings.warn(warning_msg, FutureWarning, stacklevel=1)


def _exclude_tuning(func):
    def callback(trace, draw):
        if draw.tuning:
            return

        return func(trace, draw)

    return callback


def _take_every(n: int):
    def decorator(func):
        def callback(trace, draw):
            if draw.draw_idx % n != 0:
                return

            return func(trace, draw)

        return callback

    return decorator
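
Note: both helpers are decorator factories that filter which draws reach the wrapped callback; `create_log_callback` below composes them with `_take_every` as the outer wrapper. A minimal sketch of that composition, using the two module-private helpers shown above — `FakeDraw` and `print_idx` are hypothetical stand-ins for illustration, carrying only the `tuning` and `draw_idx` fields the filters inspect:

from collections import namedtuple

# Hypothetical stand-in for PyMC's Draw object (illustration only).
FakeDraw = namedtuple("FakeDraw", ["tuning", "draw_idx"])


def print_idx(trace, draw):
    print(draw.draw_idx)


# Same wrapping order as create_log_callback: _take_every is the
# outer filter, _exclude_tuning the inner one.
filtered = _take_every(n=5)(_exclude_tuning(print_idx))

for idx in range(20):
    filtered(None, FakeDraw(tuning=idx < 10, draw_idx=idx))
# Prints 10 and 15: indices 0 and 5 are divisible by 5 but fall in
# tuning, so they are skipped.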


def create_log_callback(
    stats: list[str] | None = None,
    parameters: list[str] | None = None,
    exclude_tuning: bool = True,
    take_every: int = 100,
):
"""Create callback function to log sample stats and parameter values to MLflow during sampling.
This callback only works for the "pymc" sampler.
Parameters
----------
stats : list of str, optional
List of sample statistics to log from the Draw
parameters : list of str, optional
List of parameters to log from the Draw
exclude_tuning : bool, optional
Whether to exclude tuning steps from logging. Defaults to True.
Returns
-------
callback : Callable
The callback function to log sample stats and parameter values to MLflow during sampling
Examples
--------
Create example model:
.. code-block:: python
import pymc as pm
with pm.Model() as model:
mu = pm.Normal("mu")
sigma = pm.HalfNormal("sigma")
obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
Log off divergences and logp every 100th draw:
.. code-block:: python
import mlflow
from pymc_marketing.mlflow import create_log_callback
callback = create_log_callback(
stats=["diverging", "model_logp"],
take_every=100,
)
mlflow.set_experiment("Live Tracking Stats")
with mlflow.start_run():
idata = pm.sample(model=model, callback=callback)
Log the parameters `mu` and `sigma_log__` every 100th draw:
.. code-block:: python
import mlflow
from pymc_marketing.mlflow import create_log_callback
callback = create_log_callback(
parameters=["mu", "sigma_log__"],
take_every=100,
)
mlflow.set_experiment("Live Tracking Parameters")
with mlflow.start_run():
idata = pm.sample(model=model, callback=callback)
"""
    if not stats and not parameters:
        raise ValueError("At least one of `stats` or `parameters` must be provided.")

    def callback(_, draw):
        prefix = f"chain_{draw.chain}"
        for stat in stats or []:
            mlflow.log_metric(
                key=f"{prefix}/{stat}",
                value=draw.stats[0][stat],
                step=draw.draw_idx,
            )

        for parameter in parameters or []:
            mlflow.log_metric(
                key=f"{prefix}/{parameter}",
                value=draw.point[parameter],
                step=draw.draw_idx,
            )

    if exclude_tuning:
        callback = _exclude_tuning(callback)

    if take_every:
        callback = _take_every(n=take_every)(callback)

    return callback
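
Note: `stats` and `parameters` can also be combined in a single callback. A sketch following the docstring examples above — the experiment name and `take_every` value are illustrative choices, not from the commit:

import mlflow
import pymc as pm

from pymc_marketing.mlflow import create_log_callback

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])

# Log both sample stats and parameter values every 50th draw; metric
# keys are namespaced per chain, e.g. "chain_0/model_logp".
callback = create_log_callback(
    stats=["diverging", "model_logp"],
    parameters=["mu", "sigma_log__"],
    take_every=50,
)

mlflow.set_experiment("Live Tracking Stats and Parameters")
with mlflow.start_run():
    idata = pm.sample(model=model, callback=callback)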


def _log_and_remove_artifact(path: str | Path) -> None:
"""Log an artifact to MLflow and then remove the local file.
35 changes: 35 additions & 0 deletions tests/test_mlflow.py
@@ -28,6 +28,7 @@
from pymc_marketing.clv import BetaGeoModel
from pymc_marketing.mlflow import (
    autolog,
    create_log_callback,
    log_likelihood_type,
    log_mmm_evaluation_metrics,
    log_model_graph,
@@ -674,3 +675,37 @@ def test_log_mmm_evaluation_metrics() -> None:
    assert set(run_data.metrics.keys()) == expected_metrics

    assert all(isinstance(value, float) for value in run_data.metrics.values())


def test_callback_raises() -> None:
    match = "At least one of"
    with pytest.raises(ValueError, match=match):
        create_log_callback()


def test_logging_callback(model_with_likelihood) -> None:
    mlflow.set_experiment("pymc-marketing-test-suite-logging-callback")

    callback = create_log_callback(
        stats=["energy"],
        parameters=["mu"],
        take_every=10,
    )
    with mlflow.start_run() as run:
        pm.sample(
            model=model_with_likelihood,
            draws=100,
            tune=1,
            chains=2,
            callback=callback,
        )

    assert mlflow.active_run() is None

    run_id = run.info.run_id
    client = MlflowClient()

    for chain in [0, 1]:
        for value in ["energy", "mu"]:
            history = client.get_metric_history(run_id, f"chain_{chain}/{value}")
            assert len(history) == 10
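
Note: the expected history length follows from the two filters. With `tune=1` and `draws=100` there are 101 callback invocations per chain (assuming, as the helpers suggest, that `draw_idx` counts tuning and sampling iterations alike); the tuning draw is excluded and only indices divisible by `take_every=10` are logged, leaving the 10 metric points per key that each assertion checks.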
