Multi-path pathfinder instead of just one-path pathfinder implementation #763
Comments
@aphc14 is currently working on multi-path Pathfinder in JAX (he has a PyMC version ready: pymc-devs/pymc-extras#387)
That is amazing @junpenglao @aphc14. But it seems multi-path Pathfinder is not available in PyMC yet. Is the code pushed? When do you think the multi-path version will be ready in blackjax? Also, Pareto smoothed importance sampling doesn't always seem to help refine the posterior in the PyMC example.
@junpenglao @seongwoohan, the multi-path pathfinder in the pymc-extras PR #387 is functional but needs refining before it can be merged to main. I expect these refinements to be completed by the end of this week. With BlackJAX, I need to address 1-2 minor bugs in the current single-path Pathfinder, which give poor posterior estimates. Once that's resolved, the multi-path algorithm should be more straightforward, and I expect to have a PR for BlackJAX submitted in about 1-2 weeks. I'll update this thread when it's ready for review. @seongwoohan, the PyMC example here is very outdated and based on a buggy single-path Pathfinder. The soon-to-be-merged multi-path version will set PSIS as the default importance sampling (from my limited tests, I found PSIS more accurate and stable than PSIR), but there will be an option to skip importance sampling altogether.
@aphc14 that's awesome. Let me know here when the PyMC multi-path version is available; I would like to review and test it on several complex models that I have. Also happy to go over the BlackJAX multi-path version whenever you are ready. Question for everyone (@junpenglao @aphc14): what do you think is a generally good inference method for handling high-dimensional, multi-modal posteriors? I've been experimenting with SVGD, and while multi-path pathfinder combined with PSIS seems relatively effective, I'm curious whether there are other approaches you'd recommend.
I generally go with SMC with HMC as the kernel; it works fairly well in my experience.
Is there a way to use SMC with an HMC kernel in blackjax?
You can take a look at this example: https://blackjax-devs.github.io/sampling-book/algorithms/TemperedSMC.html#tempered-smc-with-hmc-kernel
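For readers landing here later, a minimal sketch of tempered SMC with an HMC kernel, loosely following that notebook. The `adaptive_tempered_smc` argument list has changed across blackjax releases (newer versions may require wrapping the parameters with `blackjax.smc.extend_params`), so treat the signature below as an assumption to check against your installed version; the toy log-densities are placeholders for a real model:

```python
import jax
import jax.numpy as jnp
import blackjax
import blackjax.smc.resampling as resampling

# Toy 2-D target standing in for a real model (assumption for the sketch).
def logprior_fn(position):
    return -0.5 * jnp.sum(position**2)

def loglikelihood_fn(position):
    return -0.5 * 10.0 * jnp.sum((position - 1.0) ** 2)

hmc_parameters = dict(
    step_size=1e-2,
    inverse_mass_matrix=jnp.eye(2),
    num_integration_steps=50,
)

smc = blackjax.adaptive_tempered_smc(
    logprior_fn,
    loglikelihood_fn,
    blackjax.hmc.build_kernel(),
    blackjax.hmc.init,
    hmc_parameters,
    resampling.systematic,
    target_ess=0.5,
    num_mcmc_steps=10,
)

rng_key = jax.random.PRNGKey(0)
particles = jax.random.normal(rng_key, (1_000, 2))
state = smc.init(particles)

# Anneal from the prior (lmbda = 0) to the full posterior (lmbda = 1).
while state.lmbda < 1:
    rng_key, step_key = jax.random.split(rng_key)
    state, info = smc.step(step_key, state)
```

The adaptive variant chooses each tempering step so the effective sample size stays near `target_ess`, which avoids hand-tuning the temperature schedule.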
@seongwoohan the PyMC multi-path version hasn't been merged yet and is still pending review, but it should be functional if you want to give it a try: pymc-devs/pymc-extras@862627e. Feel free to test it out on your models, and please let me know if you run into any issues or errors, or see anything unusual.
@junpenglao @seongwoohan Apologies, I'm behind on this timeframe. I will aim to have a draft PR roughly 1-2 weeks after the Lunar New Year.
@aphc14 Thank you, let me check the PyMC version! Let me know when the blackjax version is ready; happy to take a look. Also, I assume a blackjax multi-path pathfinder with NumPyro would be faster than the multi-path pathfinder with PyMC, so the blackjax version may be more useful for scalability reasons?
@seongwoohan I'm not too sure what the speed differences will be until the Blackjax multi-path pathfinder is done. Currently, in pymc-extras, the main body of Pathfinder is compiled using PyTensor, which applies rewrites to the compute graph and executes C code where possible. The Blackjax multi-path pathfinder would use mainly JAX functions in the backend and may have similar performance to using the pure-Python mode below:

```python
import pymc_extras as pmx
from pytensor import compile

# Pure-Python execution: no C linker, no graph optimizations
py_mode = compile.mode.Mode(linker="py", optimizer="None")
# Default execution: C-backed VM with graph rewrites
current_mode = compile.mode.Mode(linker="cvm_nogc", optimizer="fast_run")

with model:
    idata = pmx.fit(
        method="pathfinder",
        inference_backend="pymc",
        compile_kwargs=dict(mode=py_mode),  # slow
        # compile_kwargs=dict(mode=current_mode),  # fast
        # compile_kwargs=dict(mode="JAX"),  # not yet available
    )
```
@aphc14 Thank you, I just tried both. One concern I have: I wonder if the PyMC & blackjax multi-path pathfinders have a way to set up their initializations like the Julia multi-path pathfinder does. The Julia example I attached shows how to give customized initializations.

With a random start like the current implementation uses, pathfinder will mostly find a local optimum even though we run multiple paths. Without good initializations, the Pareto k value will be very high, so smart initialization is crucial if we want to use multi-path pathfinder efficiently; without it, it won't differ much from single-path pathfinder. Is there a way for users to set their own initialization points in the PyMC & blackjax multi-path pathfinders, as in Julia? For now, if I understand correctly, the only way to get different initializations is to control the jitter, which adds some randomness. I wonder if you are already working on custom initializations.
Currently, you can control the initial values by controlling the base values and the jitter scale. The base values come from the model's initial point:

```python
import numpy as np
import pymc as pm
from pymc.initial_point import make_initial_point_fn
from pymc.model.core import Point
from pymc.blocking import DictToArrayBijection

def eight_schools_model():
    J = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
    with pm.Model() as model:
        mu = pm.Normal("mu", mu=0.0, sigma=10.0, initval=5.0)
        tau = pm.HalfCauchy("tau", 5.0, initval=20.0)
        theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
        obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
    return model

model = eight_schools_model()
ipfn = make_initial_point_fn(model=model)
ip = Point(ipfn(None), model=model)
x_base = DictToArrayBijection.map(ip).data
print(x_base)
```

It's not recommended to set the initial values at the MAP or posterior mode with small jitter scales; doing so could reduce the algorithm's ability to properly explore the posterior space. There isn't a way to pass in an iterator to control the initial values exactly, or to use a separate jitter scale for each parameter, unfortunately. Because a working Pathfinder implementation in PyMC is still relatively new, we're still discovering ways to make improvements. Having more control over the initialisation than is currently possible may be something worth testing out. You might want to submit a feature request (or a GitHub issue to the pymc-extras repo) detailing the use case, with an example of why it would be helpful and how the results differ.

Regarding the Pareto k values: I've seen Pathfinder results with high (poor) Pareto k values that nevertheless resembled the posterior you'd get with NUTS, and that fit better (by visual comparison) than ADVI. I've included the Pareto k values in the stdout result for reference/completeness; Pareto k alone might not give you enough to determine whether the model fit is reasonable to work with.
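For illustration, controlling the jitter looks roughly like this; the `num_paths` and `jitter` argument names follow the pymc-extras PR under review and may change before merge, so check the current `pmx.fit` signature before relying on them:

```python
import pymc_extras as pmx

with model:  # the eight-schools model defined above
    idata = pmx.fit(
        method="pathfinder",
        num_paths=8,   # assumed argument: number of single-path runs
        jitter=12.0,   # assumed argument: scale of the random perturbation
                       # applied around the base initial point x_base
    )
```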
@aphc14 Thanks, and I'm happy to submit a feature request for custom initializations. In my experience playing with multi-path pathfinder in Julia, giving custom initializations was pretty effective for high-dimensional posteriors. With random starts, it sometimes fails to find the highest mode and falls into local optima. The Julia code example I attached shows such a situation, where the goal is to find the highest mode and a couple of others with multi-path pathfinder when the dimension is in the hundreds. In these cases, since HMC would take a long time to mix, multi-path pathfinder is preferred for finding well-approximated highest and secondary modes in a moderate running time. The Pathfinder paper shows the importance of smart initializations in Figure 15. So does the blackjax multi-path pathfinder follow PyMC's and not have custom initializations at the moment? If the blackjax version is already under development, I would like to help add custom initializations.
@seongwoohan Agreed, it would be great to have initialisation control for the individual parameters. I am thinking of updating this feature request pymc-devs/pymc#7555 to add this in. I plan to submit a draft PR to blackjax in the coming days (TBC) re multi-path pathfinder. Feel free to help out with multi-path pathfinder once the draft PR is submitted.
@aphc14 Great. Once the final version is out, I'm also happy to provide comparisons (e.g., a Jupyter notebook) between random and preset-initialization multi-path pathfinders on high-dimensional data for future users. Let me know here when the blackjax version is ready!
@seongwoohan sounds good!
Presentation of the new sampler
Pathfinder is already implemented in blackjax. However, if I understand correctly, this implementation is a single-path run (Algorithm 1 in the paper). I propose Algorithm 2, multi-path Pathfinder, to approximate complex posteriors faster.
A multi-path pathfinder can approximate complex posteriors better: following Algorithm 2, we run the single-path pathfinder multiple times with different initializations in parallel using JAX's vmap, and then apply ArviZ's Pareto smoothed importance sampling, as sketched below.
Julia currently has a multi-path pathfinder implementation (https://github.com/mlcolab/Pathfinder.jl/blob/52fef889ec6b579e35d251c0fc5b74628214b98f/src/multipath.jl#L68-L117), which is a good reference.
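A rough sketch of what this could look like on top of blackjax's existing single-path pathfinder. Every signature here is an assumption rather than the final API; in particular, `pathfinder.sample` is assumed to return both the draws and the approximation's log-density, which is what the importance weights need:

```python
import jax
import jax.numpy as jnp
import numpy as np
import arviz as az
import blackjax


def multipath_pathfinder(rng_key, logdensity_fn, initial_positions, num_draws_per_path):
    """Sketch of Algorithm 2: vmap single-path Pathfinder over several
    initial positions, pool the draws, and reweight them with PSIS."""
    pathfinder = blackjax.pathfinder(logdensity_fn)

    def single_path(key, x0):
        approx_key, sample_key = jax.random.split(key)
        state, _info = pathfinder.approximate(approx_key, x0)
        # Assumption: sample also returns log q(draw) under the local
        # normal approximation; adapt to the actual return value.
        draws, log_q = pathfinder.sample(sample_key, state, num_draws_per_path)
        log_w = jax.vmap(logdensity_fn)(draws) - log_q  # unnormalized log-weights
        return draws, log_w

    keys = jax.random.split(rng_key, initial_positions.shape[0])
    draws, log_w = jax.vmap(single_path)(keys, initial_positions)

    # Pool all paths and Pareto-smooth the importance weights (ArviZ's psislw).
    flat_draws = draws.reshape(-1, draws.shape[-1])
    smoothed_log_w, pareto_k = az.psislw(np.asarray(log_w).reshape(-1))

    # Resample the pooled draws according to the smoothed weights.
    resample_key = jax.random.fold_in(rng_key, 1)
    idx = jax.random.categorical(
        resample_key, jnp.asarray(smoothed_log_w), shape=(num_draws_per_path,)
    )
    return flat_draws[idx], pareto_k
```

One open design question is whether the PSIS step belongs inside blackjax or should be left to the user, since ArviZ already provides `psislw` and the smoothing itself is not a JAX computation.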
How does it compare to other algorithms in blackjax?
In a high-dimensional, multimodal posterior, I can't think of a better algorithm than multi-path pathfinder.
Where does it fit in blackjax
It's an extension of the current pathfinder.
Are you willing to open a PR?
Yes, or is it already under development?