Current MEADS implementation is incomplete #781
Comments
@albcab thoughts?
Hi @alexlyttle, you are right: the current adaptation routine is not exactly the one used in the paper. It was a long time ago, but I think in the end I wanted it to be an adaptation algorithm. MEADS is an adaptive algorithm (as in real adaptive MCMC, where you never stop adapting and are still guaranteed convergence to your target distribution). Adaptation algorithms, by contrast, are meant to stop adapting and then fix a kernel that you use to generate your samples; since the kernel no longer changes, you don't need to worry about changing hyperparameters ruining convergence results. When I wrote the original code, adaptation algorithms would run for a specified number of steps and return a kernel (GHMC in this case) to fix and use (here is the original code and discussion). Real MEADS would've returned
Anyway, the API has changed: adaptation algorithms now have stepping functions, and adaptive algorithms can be implemented more easily. So, by all means, go ahead and share/pull your complete adaptive MEADS!
Thanks @albcab, super informative! My follow-up question would have been about how the paper presents MEADS as an adaptive algorithm rather than one where you'd need to freeze the kernel. This makes sense though; I can see your reasoning in the original PR. I am happy to share my version and contribute if it's welcome. Practically, should this be added as its own MCMC algorithm in BlackJAX? My understanding of the paper is that MEADS is to GHMC what NUTS is to HMC (in the sense that it is adaptive).
Current behavior
The MEADS adaptation routine appears to be incomplete. Currently, cross-chain statistics are computed each iteration and used to update the kernel parameters for all chains. This is missing some aspects of Algorithm 3 from Hoffman & Sountsov (2022). Perhaps this was on purpose, in which case I would be interested to know why.
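To make the current behaviour concrete, here is a minimal sketch of computing cross-chain statistics and deriving kernel parameters from them. The formulas below (per-dimension scale as a preconditioner, a step size inversely proportional to the largest scale) are illustrative placeholders, not the exact MEADS statistics from the paper or the BlackJAX code:

```python
import numpy as np

def cross_chain_statistics(positions):
    """Illustrative cross-chain statistics (NOT the exact MEADS formulas).

    positions: array of shape (num_chains, num_dims) holding the current
    state of every chain.
    """
    # Per-dimension scale estimated across chains, usable as a diagonal
    # preconditioner for the GHMC kernel.
    scale = positions.std(axis=0)
    # A crude step-size heuristic, hypothetical and for illustration only:
    # inversely proportional to the largest per-dimension scale.
    step_size = 0.5 / scale.max()
    return scale, step_size

rng = np.random.default_rng(0)
chains = rng.normal(size=(128, 4))  # 128 chains, 4 dimensions
scale, step_size = cross_chain_statistics(chains)
```

In the current implementation, statistics like these are recomputed every iteration from the pooled state of all chains and applied to every chain's kernel at once, which is the point of difference from Algorithm 3.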
Desired behavior
Hoffman & Sountsov (2022) describe an algorithm where the chains are split into $K$ folds. Cross-chain statistics are computed within each fold and used to update the neighbouring fold each iteration (skipping the fold equal to the current iteration modulo $K$). The paper also describes shuffling all chains every $K$ steps. It appears the original author implemented the algorithm in this notebook. I have recently experimented with modifying the BlackJAX MEADS to reflect this for a project testing new MCMC adaptation algorithms.
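The fold scheduling above can be sketched as follows. This is a hedged sketch of the structure only: `update_fold` is a hypothetical callback standing in for the MEADS parameter update, the choice of fold $k-1$ as the "neighbouring" source is a guess at the direction, and the shuffling condition is simplified:

```python
import numpy as np

def meads_style_fold_update(chains, step, num_folds, update_fold, rng):
    """Sketch of the K-fold scheme of Hoffman & Sountsov (2022), Alg. 3.

    chains: array of shape (num_chains, num_dims); num_chains is assumed
    to be divisible by num_folds. update_fold is a hypothetical callback
    that computes statistics from one fold's chains and applies them to
    another fold's kernel parameters.
    """
    num_chains = chains.shape[0]
    folds = np.split(np.arange(num_chains), num_folds)
    skip = step % num_folds  # this fold is left untouched this iteration
    for k in range(num_folds):
        if k == skip:
            continue
        # Statistics come from a neighbouring fold (direction assumed here).
        source = folds[(k - 1) % num_folds]
        update_fold(chains, source_idx=source, target_idx=folds[k])
    # Every K steps, shuffle chains so fold membership changes over time.
    if step % num_folds == 0:
        rng.shuffle(chains, axis=0)
    return chains
```

The key structural points, as I read the paper, are that each fold's parameters are set from a *different* fold's states (so the statistics are independent of the states they drive), and that one fold per iteration is skipped.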
I propose updating the existing implementation to include the $K$-folding and shuffling described in the paper. This would introduce a few more parameters to `blackjax.meads_adaptation`, which could take default values from the paper. The `step_size_multiplier` and `damping_slowdown` are hyper-parameters used in calculating the MEADS statistics.
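For concreteness, a sketch of what the extended entry point could look like. The `num_folds` parameter and every default shown here are proposals/placeholders, not the actual `blackjax.meads_adaptation` signature, and the stub body just collects the settings:

```python
def meads_adaptation(
    logdensity_fn,
    num_chains,
    num_folds=4,                # hypothetical new parameter: K from the paper
    step_size_multiplier=0.5,   # placeholder default, not BlackJAX's actual one
    damping_slowdown=1.0,       # placeholder default, not BlackJAX's actual one
):
    """Hypothetical extended signature; here it simply returns the settings."""
    return {
        "num_chains": num_chains,
        "num_folds": num_folds,
        "step_size_multiplier": step_size_multiplier,
        "damping_slowdown": damping_slowdown,
    }

settings = meads_adaptation(lambda x: -0.5 * x * x, num_chains=128)
```

Keeping the new parameters keyword-only with paper defaults would leave existing call sites working unchanged.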