
Current MEADS implementation is incomplete #781

Open
alexlyttle opened this issue Feb 26, 2025 · 3 comments
Current behavior

The MEADS adaptation routine appears to be incomplete. Currently, cross-chain statistics are computed each iteration and used to update the kernel parameters for all chains at once. This is missing some aspects of Algorithm 3 from Hoffman & Sountsov (2022). Perhaps this was on purpose, in which case I would be interested to know why.

meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
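To make the described behaviour concrete, here is a minimal illustration (not BlackJAX code) of what pooling cross-chain statistics over all chains looks like. `pooled_meads_params` and its formulas are hypothetical stand-ins, not the actual MEADS statistic calculations:

```python
import numpy as np

def pooled_meads_params(positions, grads):
    """Illustrative sketch: every statistic is pooled over *all* chains
    at once, so a single parameter set is shared by the whole ensemble.
    The formulas below are stand-ins, not the exact MEADS update rules.

    positions, grads: arrays of shape (num_chains, dim).
    """
    # Per-coordinate variance across chains, e.g. as a diagonal preconditioner.
    precond = positions.var(axis=0)
    # A crude step-size proxy from the pooled gradient scale (illustrative only).
    step_size = 1.0 / np.sqrt(1e-8 + (precond * grads.var(axis=0)).max())
    return {"preconditioner": precond, "step_size": step_size}
```

Because a single parameter set drives every chain, the adaptation never holds any subset of chains fixed, which is where the fold structure of Algorithm 3 differs.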

Desired behavior

Hoffman & Sountsov (2022) describe an algorithm where the chains are split into $K$ folds. Cross-chain statistics are computed within each fold and used to update the neighbouring fold each iteration (skipping the fold equal to the current iteration modulo $K$). The paper also describes shuffling all chains every $K$ steps. It appears the original author implemented the algorithm in this notebook. I have recently experimented with modifying the BlackJAX MEADS to reflect this for a project testing new MCMC adaptation algorithms.

I propose updating the existing implementation to include the $K$-folding and shuffling described in the paper. This would introduce a few more parameters to blackjax.meads_adaptation, which could take default values from the paper.

meads = blackjax.meads_adaptation(logdensity_fn, num_chains, num_folds=4, shuffle=True, step_size_multiplier=0.5, damping_slowdown=1.0)

The step_size_multiplier and damping_slowdown are hyper-parameters used in calculating the MEADS statistics.
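The $K$-fold schedule described above can be sketched as follows. This is a minimal NumPy illustration, not a BlackJAX implementation: it assumes chains are assigned to folds contiguously, and `update_params` is a hypothetical stand-in for the MEADS statistic computations (step size, damping, preconditioner):

```python
import numpy as np

def meads_fold_schedule_step(rng, step, positions, params, num_folds, update_params):
    """One adaptation step following the K-fold scheme (hedged sketch).

    positions: array of shape (num_chains, dim), chains assigned to folds
    contiguously; params: dict mapping fold index -> kernel parameters;
    update_params: hypothetical function mapping a fold's states to parameters.
    """
    num_chains = positions.shape[0]
    fold_size = num_chains // num_folds
    folds = positions.reshape(num_folds, fold_size, -1)

    skipped = step % num_folds  # this fold's parameters stay frozen this step
    for k in range(num_folds):
        if k == skipped:
            continue
        neighbour = (k - 1) % num_folds
        # Parameters for fold k come from a *neighbouring* fold's statistics,
        # never from fold k's own current states.
        params[k] = update_params(folds[neighbour])

    if step % num_folds == 0:
        # Periodically shuffle the chain-to-fold assignment.
        positions = positions[rng.permutation(num_chains)]
    return positions, params
```

As I understand the paper, skipping fold `step % num_folds` means each fold's kernel parameters never depend on its own current state, and the periodic shuffle keeps the fold assignment from becoming stale.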

@junpenglao (Member)

@albcab thoughts?

@albcab (Member) commented Feb 26, 2025

Hi @alexlyttle,

You are right, the current adaptation routine is not exactly the one used in the paper. It was a long time ago, but I think in the end I wanted it to be an adaptation algorithm. MEADS is an adaptive algorithm (as in real adaptive MCMC: you never stop adapting, and you are still sure it will converge to your distribution). Adaptation algorithms, by contrast, are meant to stop adapting and then fix a kernel that you use to generate your samples; since that kernel doesn't change, you don't need to worry about changing hyperparameters ruining convergence results.

When I wrote the original code, adaptation algorithms would run for a specified number of steps and would return a kernel (GHMC in this case) to fix and use (here is the original code and discussion). Real MEADS would have returned $K$ kernels, which was an API problem. I also opened a pull request to make it more adaptive-ish here, but it was not merged; I can't remember why.

Anyway, the API has changed: adaptation algorithms now have stepping functions, and adaptive algorithms can be implemented more easily. So, by all means, go ahead and share/pull your complete adaptive MEADS!

@alexlyttle (Author) commented Feb 27, 2025

Thanks @albcab, super informative! My follow-up question would have been about how the paper presents MEADS as an adaptive algorithm, rather than something where you'd need to freeze the kernel. This makes sense though; I can see your reasoning in the original PR.

I am happy to share my version and contribute if it's welcome.

Practically, should this be added as its own MCMC algorithm in BlackJAX? My understanding of the paper is that MEADS is to GHMC what NUTS is to HMC (in the sense that it is adaptive).
