diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 6903ddc2c..7c636181f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -22,7 +22,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import IntegratorState, isokinetic_mclachlan from blackjax.types import ArrayLike, PRNGKey -from blackjax.util import generate_unit_vector +from blackjax.util import generate_unit_vector, pytree_size __all__ = ["MCLMCInfo", "init", "build_kernel", "mclmc"] @@ -45,6 +45,10 @@ class MCLMCInfo(NamedTuple): def init(position: ArrayLike, logdensity_fn, rng_key): + if pytree_size(position) < 2: + raise ValueError( + "The target distribution must have more than 1 dimension for MCLMC." + ) l, g = jax.value_and_grad(logdensity_fn)(position) return IntegratorState(