
Subsampling for MCLMC tuning #738

Open
reubenharry opened this issue Sep 18, 2024 · 6 comments

Comments

@reubenharry
Contributor

Current behavior

The third stage of tuning computes the effective sample size using the correlation length averaged across dimensions. For high-dimensional problems, it might be more efficient to subsample only n of the dimensions.

Desired behavior

Have a default maximum number of dimensions n, and run a chain in the third stage of tuning that only returns a random subset of n dimensions.
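
A minimal sketch of the idea, with illustrative names only (subsampled_ess, ess_fn, and max_dims are not part of the blackjax API): draw a random subset of coordinates and compute the ESS on those alone.

import jax

def subsampled_ess(rng_key, samples, ess_fn, max_dims=128):
    """Estimate ESS on at most `max_dims` randomly chosen coordinates.

    `samples` has shape (n_steps, n_dims); `ess_fn` stands for whatever ESS
    estimator the third tuning stage already uses (averaged over dimensions).
    """
    n_dims = samples.shape[-1]
    if n_dims <= max_dims:
        return ess_fn(samples)
    idx = jax.random.choice(rng_key, n_dims, shape=(max_dims,), replace=False)
    return ess_fn(samples[:, idx])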

@hsimonfroy
Contributor

hsimonfroy commented Feb 15, 2025

Hi,

I encountered this issue but found it more convenient to solve it by performing thinning, which, by the way, the run_inference_algorithm API does not implement yet. In my high-dimensional (>10^7) use case, it is not practical to scan over and return every MCMC step when the chain needs about 500 steps per effective sample. In that case, the ESS estimate would barely change if 99 out of every 100 steps were discarded.

Thinning is implemented in the numpyro MCMC API. In blackjax, it can simply be done with

def run_inference_algorithm(..., thinning=1):

    ...

    def one_sub_step(state, rng_key):
        state, info = inference_algorithm.step(rng_key, state)
        return state, info

    def one_step(state, xs):
        _, rng_key = xs
        # Run `thinning` inner steps per outer step: only the last state is
        # kept, and `info` stacks the infos of the inner steps.
        keys = jr.split(rng_key, thinning)
        state, info = lax.scan(one_sub_step, state, keys)
        return state, transform(state, info)

and the same in the MCLMC adaptation. This way we kill two birds with one stone.

@reubenharry
Contributor Author

reubenharry commented Feb 16, 2025

Thanks for this comment! Thinning makes sense, but I'd like to implement it as a kernel transformation rather than complicating run_inference_algorithm. That is, make_thinned_kernel would take a kernel and return a new kernel that does the appropriate thing. This feels more functional, and I have a draft of it on a branch somewhere.

But yes, really either way would be fine.

@reubenharry
Contributor Author

Separately, we've recently been interested in using MCLMC for high-dimensional problems, so we would be interested in hearing about your experience.

@hsimonfroy
Contributor

make_thinned_kernel would take a kernel and return a new kernel that does the appropriate thing.

Agreed. I have also looked at this a bit; maybe it could simply look like:

from typing import Callable, NamedTuple

from jax import lax
import jax.random as jr
from blackjax.types import PRNGKey


def make_thinned_kernel(kernel, thinning=1, info_transform=lambda x: x) -> Callable:
    """Return a thinned version of `kernel` that iterates `thinning` times
    before returning the state.
    """

    def thinned_kernel(rng_key: PRNGKey, state: NamedTuple, *args, **kwargs) -> tuple[NamedTuple, NamedTuple]:
        step = lambda state, rng_key: kernel(rng_key, state, *args, **kwargs)
        keys = jr.split(rng_key, thinning)
        # Only the final state is kept; `info` stacks the infos of the
        # `thinning` inner steps and is summarised by `info_transform`.
        state, info = lax.scan(step, state, keys)
        return state, info_transform(info)

    return thinned_kernel

Applied to the MCLMC adaptation, this would look like:

kernel = lambda inverse_mass_matrix: make_thinned_kernel(
    blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdf,
        integrator=isokinetic_mclachlan,
        inverse_mass_matrix=inverse_mass_matrix,
    ),
    thinning=10,
    info_transform=lambda info: tree.map(lambda x: (x**2).mean() ** 0.5, info),
)

state, params, n_steps = blackjax.mclmc_find_L_and_step_size(mclmc_kernel=kernel, ...)
  • The main point is that while thinning has a fixed meaning for the state (return the last one), its desired effect on the info can vary. For instance, to control the energy fluctuation we have to aggregate it over the thinning steps (taking the RMS of energy_change). This is why we could allow an info_transform parameter.
  • Concerning the top-level API, I don't know how this kernel transform should be exposed. A thinning parameter in as_top_level_api?
  • By the way, concerning the lower-level API, I can't help but be a bit confused by the heterogeneity among samplers in what is considered a build_kernel parameter versus a kernel parameter. For instance, NUTS and HMC treat inverse_mass_matrix as a kernel parameter (which makes sense to me), but this is not the case for MCLMC, adjusted-MCLMC, or adjusted-MCLMC-dynamic, so we actually have to use not the kernel but the callable
    lambda inverse_mass_matrix: build_kernel(inverse_mass_matrix, ...)
    (see the sketch after this list). Conversely, MCLMC and adjusted-MCLMC treat logdensity_fn as a "kernel building" parameter, but NUTS, HMC, and adjusted-MCLMC-dynamic do not. I guess this is hidden by the top-level API, but whenever some tuning is required, it complicates juggling between samplers a bit.
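
To make the asymmetry concrete, here is a rough, illustrative sketch; exact signatures differ between blackjax versions, so treat the argument lists below as assumptions, and logdf is the log-density from the snippet above.

import blackjax
from blackjax.mcmc.integrators import isokinetic_mclachlan

# HMC/NUTS style: inverse_mass_matrix is a parameter of the *step* function,
# so a tuned value can be supplied at call time.
hmc_kernel = blackjax.mcmc.hmc.build_kernel()
# hmc_kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps)

# MCLMC style: inverse_mass_matrix is an argument of build_kernel itself,
# so tuning code has to close over it with a lambda.
mclmc_kernel_factory = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdf,
    integrator=isokinetic_mclachlan,
    inverse_mass_matrix=inverse_mass_matrix,
)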

@hsimonfroy
Contributor

Separately, we've recently been interested in using MCLMC for high-dimensional problems, so we would be interested in hearing about your experience.

What would you like to know? We are currently benchmarking high-dimensional sampling methods for large-scale cosmological inference, and we find a significant improvement using MCLMC compared to NUTS et al. It was not so simple to tune at first, as the energy fluctuation is quite sensitive to preconditioning, so we are also working on that.

@reubenharry
Contributor Author

Thanks for writing this out!

I basically agree with this design.

Concerning the top-level API, I don't know how this kernel transform should be exposed. A thinning parameter in as_top_level_api?

See store_only_expectation_values in util.py, which is a kernel transform that takes a kernel and produces a new kernel returning a running average. I'm envisioning thinning_kernel living there (or even as an extension of store_only_expectation_values, if you need a running average).

My only reservation about putting it in as_top_level_api is that we'd have to do this for every sampler separately.
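
For illustration, here is a rough usage sketch of the kernel-transform route, reusing make_thinned_kernel from the earlier comment; blackjax.nuts is just an example sampler, and logdensity_fn, step_size, inverse_mass_matrix, rng_key, and initial_position are assumed to be defined by the user.

import blackjax

# Any sampler's step function can be wrapped, since the transform only sees
# a kernel with signature (rng_key, state) -> (state, info).
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)
thinned_step = make_thinned_kernel(nuts.step, thinning=10)

# One call now advances the chain by 10 steps and returns only the final
# state, plus the stacked (or transformed) info.
state, info = thinned_step(rng_key, nuts.init(initial_position))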

By the way, concerning the lower-level API, I can't help but be a bit confused by the heterogeneity among samplers in what is considered a build_kernel parameter versus a kernel parameter. For instance, NUTS and HMC treat inverse_mass_matrix as a kernel parameter (which makes sense to me), but this is not the case for MCLMC, adjusted-MCLMC, or adjusted-MCLMC-dynamic, so we actually have to use not the kernel but the callable

Yes, this is not particularly principled, and could be improved. If we were in a language with rich static types, I think it would be much easier to enforce consistent practices, but Python is what it is, and I wasn't disciplined enough when adding MCLMC. I can't recall whether there were specific reasons that e.g. logdensity_fn is partially applied (i.e. put into build_kernel) for particular samplers, but all else being equal, I don't see why we couldn't standardize this more and avoid the lambda abstraction you mentioned.

Relatedly, there are many improvements that could be made to the code quality of the MCLMC implementations that I have postponed, so if you see things like this, it is indeed often the case that it's improvable.

What would you like to know? We are currently benchmarking high-dimensional sampling methods for large-scale cosmological inference, and we find a significant improvement using MCLMC compared to NUTS et al. It was not so simple to tune at first, as the energy fluctuation is quite sensitive to preconditioning, so we are also working on that.

Basically I was curious about the comparison to NUTS. From what we understand, a lot of the improvement in high dimensions comes from removing the MH step, so it would be interesting to know about results using the adjusted variant (especially with the Omelyan integrator). We're currently interested in benchmarking adjusted and unadjusted MCLMC across high-dimensional problems (>10^4) in the sciences (LQCD, computational chemistry, BNNs, cosmology, etc.), so I might reach out by email at some point to follow up.
