Subsampling for MCLMC tuning #738
Comments
Hi, I encountered this issue but found it more convenient to solve by performing thinning. Thinning is implemented in `def run_inference_algorithm(..., thinning=1):` as follows:

```python
# (excerpt from inside run_inference_algorithm; jr = jax.random, lax = jax.lax)
def one_sub_step(state, rng_key):
    state, info = inference_algorithm.step(rng_key, state)
    return state, info

def one_step(state, xs):
    _, rng_key = xs
    keys = jr.split(rng_key, thinning)
    state, info = lax.scan(one_sub_step, state, keys)
    return state, transform(state, info)
```

and the same in the MCLMC adaptation. This way we kill two birds with one stone.
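For illustration, here is a hedged sketch of what calling `run_inference_algorithm` with the proposed `thinning` argument could look like from the caller's side. The `thinning` keyword is the addition suggested in this comment, not an existing blackjax option, and `sampling_alg`, `initial_state` and the exact return layout are assumptions made only for this example:

```python
import jax

rng_key = jax.random.PRNGKey(0)

# Illustrative only: `thinning=10` is the proposed argument, not part of the
# current API; `sampling_alg` and `initial_state` stand in for a built sampler
# (e.g. unadjusted MCLMC) and its initial state.
out = run_inference_algorithm(
    rng_key=rng_key,
    inference_algorithm=sampling_alg,
    num_steps=1_000,                           # number of *returned* samples
    initial_state=initial_state,
    transform=lambda state, info: state.position,
    thinning=10,                               # 10 kernel steps per returned sample
)
# `out` holds the final state plus the transformed history; the exact layout
# depends on the blackjax version.
```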
Thanks for this comment! Thinning makes sense, but I'd like to implement it as a kernel transformation, rather than complicating the existing functions. But yes, really either way would be fine.
Separately, we've recently been interested in using MCLMC for high dimensional problems, so we would be interested in hearing about your experience.
Agree. I looked at it a bit as well; maybe this could simply look like:

```python
from typing import Callable, NamedTuple

from jax import lax
import jax.random as jr
from blackjax.types import PRNGKey


def make_thinned_kernel(kernel, thinning=1, info_transform=lambda x: x) -> Callable:
    """Return a thinned version of `kernel` that iterates `thinning` times before returning the state."""

    def thinned_kernel(rng_key: PRNGKey, state: NamedTuple, *args, **kwargs) -> tuple[NamedTuple, NamedTuple]:
        step = lambda state, rng_key: kernel(rng_key, state, *args, **kwargs)
        keys = jr.split(rng_key, thinning)
        state, info = lax.scan(step, state, keys)
        return state, info_transform(info)

    return thinned_kernel
```

Applied to the MCLMC adaptation, this would look like:

```python
kernel = lambda inverse_mass_matrix: make_thinned_kernel(
    blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdf,
        integrator=isokinetic_mclachlan,
        inverse_mass_matrix=inverse_mass_matrix,
    ),
    thinning=10,
    info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
)

state, params, n_steps = blackjax.mclmc_find_L_and_step_size(mclmc_kernel=kernel, ...)
```
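To make the end-to-end flow concrete, here is a hedged continuation of the snippet above that steps the tuned, thinned kernel afterwards. The `L`, `step_size` and `inverse_mass_matrix` attributes on `params`, and the `(rng_key, state, L, step_size)` calling convention of the underlying MCLMC kernel, are assumptions about the adaptation output rather than a verified API:

```python
import jax

rng_key = jax.random.PRNGKey(1)

# Assumed continuation: `params` is the adaptation state returned by
# mclmc_find_L_and_step_size above, and `state` the tuned MCLMC state.
thinned_step = kernel(params.inverse_mass_matrix)   # build the thinned kernel
state, info = thinned_step(
    rng_key,
    state,
    L=params.L,                  # momentum decoherence length from the tuner
    step_size=params.step_size,  # step size from the tuner
)
# `info` is already the RMS-aggregated diagnostics produced by
# `info_transform` over the 10 inner steps.
```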
What would you like to know? We are currently benchmarking high dimensional sampling methods for large scale cosmological inference, and we find a significant improvement using MCLMC compared to NUTS and friends. It was not so simple to tune at first, as the energy fluctuation is quite sensitive to preconditioning, so we are also working on that.
Thanks for writing this out! I basically agree with this design.
See My only aversion to putting it in
Yes, this is not particularly principled, and could be improved. If we were in a language with rich static types, I think it would be much easier to enforce consistent practices, but Python is what it is, and I wasn't disciplined enough when adding MCLMC. I can't recall if there were specific reasons for some of those choices. Relatedly, there are many improvements that could be made to the code quality of the MCLMC implementations that I have postponed, so if you see things like this, it is indeed often the case that it's improvable.
Basically I was curious about the comparison to NUTS. From what we understand, a lot of the improvement in high dimensions comes from removing the MH step, so it would be interesting to know about results using the adjusted variant (especially with the Omelyan integrator). We're currently interested in benchmarking adjusted and unadjusted MCLMC across high dimensional problems (> 10^4 dimensions) in the sciences (lattice QCD, computational chemistry, BNNs, cosmology, etc.), so I might reach out by email at some point to follow up.
Current behavior
The third stage of tuning computes the effective sample size using the correlation length averaged across dimensions. For high dimensional problems, it might be better to subsample only n dimensions, for efficiency reasons.
Desired behavior
Have a default maximum number of dimensions n, and, in the third stage of tuning, run a chain that only returns a random choice of n dimensions.
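A minimal sketch of what the proposed subsampling could look like, assuming the stage-three chain can be given a transform that flattens the position; the helper name and the `max_dims=128` default are illustrative, not an existing blackjax API:

```python
import jax
import jax.numpy as jnp


def subsample_position(rng_key, position, max_dims=128):
    """Keep at most `max_dims` randomly chosen coordinates of a flattened position.

    Illustrative helper: `max_dims` plays the role of the default maximum
    number of dimensions proposed above.
    """
    flat = jnp.ravel(position)
    dim = flat.shape[0]
    if dim <= max_dims:
        return flat
    idx = jax.random.choice(rng_key, dim, shape=(max_dims,), replace=False)
    return flat[idx]
```

The third-stage effective sample size / correlation length estimate would then be averaged over only these `max_dims` coordinates, so the cost of that part of tuning stays roughly constant as the model dimension grows.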