diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 27b5c2e9c..e7a69849b 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -80,7 +80,9 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov)) + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float