
Multi-path pathfinder instead of just one-path pathfinder implementation #763

seongwoohan opened this issue Jan 12, 2025 · 16 comments

@seongwoohan

seongwoohan commented Jan 12, 2025

Presentation of the new sampler

Pathfinder is already implemented in blackjax. However, if I understand correctly, this implementation runs single-path Pathfinder (Algorithm 1 in the paper). I propose implementing Algorithm 2, multi-path Pathfinder, to approximate complex posteriors faster.

A multi-path Pathfinder can approximate complex posteriors better. Following Algorithm 2, we can run the single-path algorithm multiple times with different initializations in parallel using jax.vmap, and then apply ArviZ's Pareto smoothed importance sampling.
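A minimal sketch of the Algorithm 2 outer loop, where single_path is a toy gradient-ascent stand-in (the real blackjax single-path Pathfinder would take its place) and all names are illustrative:

```python
import jax
import jax.numpy as jnp

def log_density(x):
    # Toy target: standard normal posterior.
    return -0.5 * jnp.sum(x ** 2)

def single_path(init):
    # Stand-in for one single-path Pathfinder run: a few gradient
    # ascent steps toward the mode from one initial position.
    grad_fn = jax.grad(log_density)
    x = init
    for _ in range(100):
        x = x + 0.1 * grad_fn(x)
    return x

# Algorithm 2's outer loop: run many single paths from different
# initializations in parallel with vmap.
key = jax.random.PRNGKey(0)
inits = jax.random.normal(key, (8, 2)) * 5.0  # 8 paths, 2-D posterior
modes = jax.vmap(single_path)(inits)
```

A full implementation would also pool the draws from the 8 paths and reweight them with Pareto smoothed importance sampling before reporting the posterior approximation.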

Julia already has a multi-path Pathfinder implementation (https://github.com/mlcolab/Pathfinder.jl/blob/52fef889ec6b579e35d251c0fc5b74628214b98f/src/multipath.jl#L68-L117), so it is a good reference.

How does it compare to other algorithms in blackjax?

For a high-dimensional, multimodal posterior, I can't think of a better algorithm than multi-path Pathfinder.

Where does it fit in blackjax

It's an extension of current pathfinder.

Are you willing to open a PR?

Yes, or is it already under development?

@junpenglao
Member

@aphc14 is currently working on Multi-pathfinder in JAX (he has a PyMC version ready pymc-devs/pymc-extras#387)

@seongwoohan
Author

That is amazing @junpenglao @aphc14. But it seems multi-path Pathfinder is not available in pymc yet. Has the code been pushed?

When do you think the multi-path version will be ready in blackjax? Also, Pareto smoothed importance sampling doesn't always seem to help refine the posterior in the pymc example.

@aphc14

aphc14 commented Jan 13, 2025

@junpenglao @seongwoohan, the multi-path pathfinder in the pymc-extras PR #387 is functional but needs refining before it can be merged to main. I expect these refinements to be completed by the end of this week.

With BlackJAX, I need to address 1-2 minor bugs in the current single-path Pathfinder that give poor posterior estimates. Once that's resolved, the multi-path algorithm should be more straightforward, and I expect to submit a PR for BlackJAX in about 1-2 weeks. I'll update this thread when it's ready for review.

@seongwoohan, the pymc example here is very outdated and based on a buggy single-path Pathfinder. The soon-to-be-merged multi-path version will set PSIS as the default importance sampling (from my limited tests, I found PSIS more accurate and stable than PSIR), but there will be an option to skip importance sampling entirely.
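For reference, here is what plain importance resampling looks like in NumPy. PSIS additionally Pareto-smooths the largest log weights (e.g. via arviz.psislw) before resampling, which this sketch omits; all quantities here are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend each path produced draws with log importance weights
# log p(x) - log q(x); both are made up for this example.
draws = rng.normal(size=(1000, 2))           # pooled draws from all paths
log_w = -0.5 * np.sum(draws ** 2, axis=1)    # hypothetical log weights

# Plain importance resampling: normalise weights, resample with
# replacement.  PSIS would first fit a generalised Pareto tail to the
# largest weights and replace them with smoothed values.
w = np.exp(log_w - log_w.max())  # subtract max for numerical stability
w /= w.sum()
idx = rng.choice(len(draws), size=500, replace=True, p=w)
resampled = draws[idx]
```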

@seongwoohan
Author

@aphc14 that's awesome. Let me know here as well when pymc multi-path is available. I would like to review and test it on several complex models that I have. Also happy to go over the BlackJAX multi-path whenever you're ready.

Question for everyone (@junpenglao @aphc14): What do you think is a generally good inference method for handling high-dimensional, multi-modal posteriors? I've been experimenting with SVGD, and while multi-pathfinder combined with PSIS seems relatively effective, I'm curious if there are other approaches you'd recommend.

@junpenglao
Member

I generally go with SMC with an HMC kernel; it works fairly well in my experience.

@seongwoohan
Author

Is there a way to use SMC with an HMC kernel in blackjax?

@junpenglao
Member

@aphc14

aphc14 commented Jan 24, 2025

> @aphc14 that's awesome. Let me know in here as well if pymc multi-path is available. I would like to review and test on several complex models that I have. Also happy to go over BlackJAX multi-path whenever you are ready.

@seongwoohan pymc multi-path hasn't been merged yet and is still pending review, but it should be functional if you want to give it a try: pymc-devs/pymc-extras@862627e. Feel free to test it on your models, and please let me know if you hit any issues, errors, or anything unusual.

> ... I expect to have a PR for BlackJAX submitted in about 1-2 weeks. I'll update this thread when it's ready for review.

@junpenglao @seongwoohan Apologies, I'm behind this timeframe. I'll aim to have a draft PR roughly 1-2 weeks after Lunar New Year.

@seongwoohan
Author

@aphc14 Thank you, let me check the pymc version! Let me know when the blackjax version is ready; happy to take a look. Also, I assume blackjax multi-path Pathfinder with numpyro would be faster than the pymc version, so the blackjax version may be more useful for scalability reasons?

@aphc14

aphc14 commented Feb 13, 2025

@seongwoohan I'm not too sure what the speed differences will be until the Blackjax multi-path Pathfinder is done. Currently, in pymc-extras, the main body of Pathfinder is compiled with pytensor, which applies rewrites to the compute graph and executes C code where possible through the cvm_nogc linker. You can compare the speed of pure Python mode versus the current mode with the code snippet below.

Blackjax multi-path Pathfinder would mainly use JAX functions in the backend and may have performance similar to compile_kwargs=dict(mode="JAX") in pytensor. However, selecting the JAX compiler in pytensor is not yet available; see pymc-devs/pymc-extras#425.

from pytensor import compile
import pymc_extras as pmx

py_mode = compile.mode.Mode(linker="py", optimizer="None")
current_mode = compile.mode.Mode(linker="cvm_nogc", optimizer="fast_run")

with model:
    idata = pmx.fit(
        method="pathfinder",
        inference_backend="pymc",
        compile_kwargs=dict(mode=py_mode),  # slow
        # compile_kwargs=dict(mode=current_mode),  # fast
        # compile_kwargs=dict(mode="JAX"),  # not yet available
    )

@seongwoohan
Author

seongwoohan commented Feb 15, 2025

@aphc14 Thank you, I just tried both, and current_mode is faster.

One concern: I wonder whether the pymc & blackjax multi-path Pathfinders have a way to set up their initializations the way the Julia multi-path Pathfinder does here.

Here is how the example below passes a customized init to multi-path Pathfinder.

function fit_cyclic_model(g::interventionGraph, log_normalize::Bool, model_pars::NamedTuple, sampling_pars::NamedTuple)
    Turing.setadbackend(:forwarddiff)

    @info "$(now()) preparing matrices for the model"
    T, t, edges = get_cyclic_matrices(g, log_normalize, true, true)
    T = sparse(T)

    model = joint_cyclic_model(g.nv, model_pars, T, t)

    init_β = 0.3 .* (T \ t)
    noise_dist = Normal(0, .10)
    init = vcat(
        # zeros(g.nv * (g.nv - 1)), #beta
        init_β,
        # -5.0
        fill(-3.0, g.nv) # σ
    )
    @info "$(now()) running pathfinder now"
    nruns = 5
    # 0.72 seconds with gradientdescent()
    # 0.10 with LBFGS
    # result_multi = multipathfinder(
    result_multi = wrap_multipathfinder(
        model, 
        2_000;
        nruns = nruns,
        # ad_backend = AD.ReverseDiffBackend(),
        optimizer = Optim.LBFGS(; 
            m = 6,
            linesearch = HagerZhang(), 
            alphaguess = InitialHagerZhang(α0=0.8),
        ),
        init = [Float64.(init .+ rand(noise_dist, length(init))) for i in 1:nruns],
        # init_scale = 0.01,
        importance = true,
        ndraws_elbo = 100,
        # show_trace = false,
        # ntries = 100
        # iterations = 200
    )
    @info "$(now()) done running pathfinder"
    model_chain = result_multi.draws_transformed 
    chains_params = Turing.MCMCChains.get_sections(model_chain, :parameters)
    quantities = generated_quantities(model, chains_params)
    # return model_chain, quantities, path_init, sampler, vec_β_init
    return model_chain, quantities, model, result_multi
end

With a random start like the current implementation uses, Pathfinder will mostly find a local optimum even when we run multiple paths. Without good initializations, the Pareto k value will be very high, so a smart initialization is crucial if we want to use multi-path Pathfinder efficiently; without it, the result won't differ much from single-path Pathfinder. Is there a way for users to set their own initialization points in the pymc & blackjax multi-path Pathfinder, like in Julia? For now, if I understand correctly, the only way to get different initializations is to control the jitter, which introduces some randomness.

I wonder if you are already working on custom initializations.
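For comparison, the per-run perturbed initialization in the Julia snippet above (init .+ rand(noise_dist, length(init)) for each run) would look roughly like this in Python; the base vector is a stand-in, not the real model's:

```python
import numpy as np

rng = np.random.default_rng(42)

nruns = 5
# Stand-ins for the model-specific pieces of the Julia snippet:
# init_beta plays the role of 0.3 .* (T \ t), and the -3.0 block
# is the log-sigma part of the parameter vector.
init_beta = 0.3 * np.ones(10)
base = np.concatenate([init_beta, np.full(4, -3.0)])

noise_sd = 0.10  # matches Normal(0, .10) in the Julia code

# One perturbed copy of the shared base point per path, mirroring
# `init = [init .+ rand(noise_dist, length(init)) for i in 1:nruns]`.
inits = [base + rng.normal(0.0, noise_sd, size=base.shape) for _ in range(nruns)]
```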

@aphc14

aphc14 commented Feb 17, 2025

Currently, you can control the initial values by controlling the base values and the jitter scale via the jitter argument in the function call, since initial values = base + jitter draw. To change the base value, you can set initval for a random variable inside pm.Model. For example:

import numpy as np
import pymc as pm

from pymc.initial_point import make_initial_point_fn
from pymc.model.core import Point
from pymc.blocking import DictToArrayBijection


def eight_schools_model():
    J = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0.0, sigma=10.0, initval=5.0)
        tau = pm.HalfCauchy("tau", 5.0, initval=20.0)

        theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
        obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)

    return model


model = eight_schools_model()

ipfn = make_initial_point_fn(model=model)
ip = Point(ipfn(None), model=model)
x_base = DictToArrayBijection.map(ip).data

print(x_base)
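The "initial values = base + jitter draw" relationship can then be sketched as follows; the uniform jitter distribution and the scale here are assumptions for illustration, not necessarily what pymc-extras does:

```python
import numpy as np

rng = np.random.default_rng(0)

# x_base as computed above (stand-in values here); jitter is a
# hypothetical scale like the one passed to the sampler.
x_base = np.zeros(10)
jitter = 2.0
npaths = 4

# One jittered initial point per path: base + jitter draw.
# The uniform(-1, 1) draw is an assumption for illustration.
init_points = x_base + jitter * rng.uniform(-1.0, 1.0, size=(npaths, x_base.size))
```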

It's not recommended to set the initial values at the MAP or posterior mode with small jitter scales. Doing so could reduce the algorithm's ability to properly explore the posterior space.

There isn't a way to pass in an iterator to control initial values exactly, or to have a separate jitter scale for each parameter, unfortunately. Because a working Pathfinder implementation in PyMC is still relatively new, we're still discovering ways to improve it. Having more control over the initialisation than is currently possible may be worth testing. You might want to submit a feature request (or a GitHub issue on the pymc-extras repo) detailing the use case, with an example of why this would be helpful and how the results differ.

Regarding Pareto k values: I've seen Pathfinder results with high (poor) Pareto k values that nevertheless resembled the posterior you'd get with NUTS, and fit better (by visual comparison) than ADVI. I've included the Pareto k values in the stdout result for reference/completeness. Pareto k values alone might not be enough to determine whether the model fit is reasonable to work with.

@seongwoohan
Author

seongwoohan commented Feb 17, 2025

@aphc14 Thanks; I'm happy to submit a feature request for custom initializations. In my experience playing with multi-path Pathfinder in Julia, custom initializations were quite effective on high-dimensional posteriors. With random starts, it sometimes fails to find the highest mode and falls into local optima.

The Julia code example I attached covers such a situation: the goal is to find the highest mode and a couple of others with multi-path Pathfinder when the dimension is in the hundreds. In these cases, since HMC would take a long time to mix, multi-path Pathfinder is preferred for finding well-approximated highest and secondary modes in moderate running time. The Pathfinder paper shows the importance of smart initializations in Figure 15.

Then does the blackjax multi-path Pathfinder follow pymc's, without custom initializations for now? If the blackjax version is already developed, I would like to help add custom initializations.

@aphc14

aphc14 commented Feb 18, 2025

@seongwoohan Agreed--it would be great to have initialisation control for individual parameters. I'm thinking of updating this feature request pymc-devs/pymc#7555 to add start and update the jitter_scale argument so it can be specified at the parameter level. As for passing an iterator for the initialisations, I think it'll be easier to handle that within fit_pathfinder in pymc-extras. Happy for you to submit a GitHub issue to the pymc-extras repo for this.

I plan to submit a draft PR to blackjax in the coming days (TBC) for multi-path Pathfinder. Feel free to help out once the draft PR is submitted.

@seongwoohan
Author

@aphc14 Great. Once the final version lands, I'm also happy to provide comparisons (e.g., a Jupyter notebook) between random and preset initializations for multi-path Pathfinder on high-dimensional data for future users. Let me know here when the blackjax version is ready!

@aphc14

aphc14 commented Feb 19, 2025

@seongwoohan sounds good!
