The example usage code from barker_mcmc
barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)
raises an error
barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
Possibly related to #723. I got it to work for a single chain:
import jax
import jax.numpy as jnp
import blackjax
# log_density_fn, rng_key and num_samples are defined elsewhere (not shown)
step_size = 0.1
# initial_states = jnp.repeat(jnp.array([15., 15., 0., 1., 0.])[jnp.newaxis, :], axis=0, repeats=num_chains)
initial_position = jnp.array([15., 15., 0., 1., 0.])  # shape (5,)
barker_state = blackjax.mcmc.barker.init(position=initial_position, logdensity_fn=log_density_fn)
barker_kernel = blackjax.barker.build_kernel()
# a single step works
new_state, info = barker_kernel(rng_key=rng_key, state=barker_state, logdensity_fn=log_density_fn, step_size=step_size)
@jax.jit
def one_step(state, subkey):
    new_state, info = barker_kernel(subkey, state, log_density_fn, step_size)
    # record the position reached after this step
    return new_state, new_state.position
keys = jax.random.split(rng_key, num_samples)
final_state, positions = jax.lax.scan(one_step, barker_state, keys)
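For reference, a single-chain loop like the one above can be extended to several chains by wrapping the `jax.lax.scan` in a function and calling it through `jax.vmap` over per-chain keys and starting points (the commented-out `initial_states` line points in this direction). The sketch below is self-contained and does not call blackjax: `barker_step` is a hand-written Barker proposal step, not blackjax's implementation, and `barker_step`, `run_chain` and `run_chains` are hypothetical names. Each coordinate of a Gaussian increment `z` keeps its sign with probability `sigmoid(z * grad)` (so moves lean uphill), and the Metropolis-Hastings correction uses the ratio of the resulting proposal densities. The same vmap-over-scan pattern should carry over to the blackjax kernel, but that is an assumption here, not something tested against blackjax.

```python
import jax
import jax.numpy as jnp


def barker_step(rng_key, position, logdensity_fn, step_size):
    """One Metropolis-Hastings step with a Barker proposal (illustrative sketch)."""
    key_z, key_b, key_u = jax.random.split(rng_key, 3)
    logp_x, grad_x = jax.value_and_grad(logdensity_fn)(position)

    # Gaussian increment per coordinate; keep its sign with probability
    # sigmoid(z * grad), otherwise flip it, so moves favour uphill directions.
    z = step_size * jax.random.normal(key_z, position.shape)
    keep = jax.random.uniform(key_b, position.shape) < jax.nn.sigmoid(z * grad_x)
    proposal = position + jnp.where(keep, z, -z)

    logp_y, grad_y = jax.value_and_grad(logdensity_fn)(proposal)
    delta = proposal - position
    # log q(proposal -> position) - log q(position -> proposal); Gaussian parts cancel.
    log_q_ratio = jnp.sum(jax.nn.softplus(-delta * grad_x) - jax.nn.softplus(delta * grad_y))
    log_accept = logp_y - logp_x + log_q_ratio
    accept_prob = jnp.minimum(1.0, jnp.exp(log_accept))

    accepted = jnp.log(jax.random.uniform(key_u)) < log_accept
    return jnp.where(accepted, proposal, position), accept_prob


def run_chain(rng_key, initial_position, logdensity_fn, step_size, num_samples):
    """Single chain: scan barker_step and record every new position."""
    def one_step(position, key):
        new_position, accept_prob = barker_step(key, position, logdensity_fn, step_size)
        return new_position, (new_position, accept_prob)

    keys = jax.random.split(rng_key, num_samples)
    _, (samples, accept_probs) = jax.lax.scan(one_step, initial_position, keys)
    return samples, accept_probs


def run_chains(rng_key, initial_positions, logdensity_fn, step_size, num_samples):
    """Several chains: vmap run_chain over per-chain keys and starting points."""
    keys = jax.random.split(rng_key, initial_positions.shape[0])

    def run(key, x0):
        return run_chain(key, x0, logdensity_fn, step_size, num_samples)

    return jax.vmap(run)(keys, initial_positions)


# Usage with a stand-in target (standard normal in 5 dimensions) and four
# chains started from the point used in the issue.
def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


num_chains = 4
initial_positions = jnp.tile(jnp.array([15.0, 15.0, 0.0, 1.0, 0.0]), (num_chains, 1))
samples, accept_probs = run_chains(
    jax.random.PRNGKey(0), initial_positions, logdensity_fn, 0.5, 2000
)
# samples.shape == (4, 2000, 5), accept_probs.shape == (4, 2000)
```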
but I can't wrap my head around the window adaptation; it keeps giving an error. Judging by this line from the quickstart page,
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
I need to pass the Barker kernel in some way, but the one I pass raises an error. Alternatively, this window adaptation attempt
barker_kernel = blackjax.mcmc.barker.as_top_level_api(logdensity_fn=log_density_fn, step_size=step_size, inverse_mass_matrix=jnp.eye(5))
barker_kernel.step(rng, initial_state)
blackjax.window_adaptation(algorithm=barker_kernel, logdensity_fn=log_density_fn, is_mass_matrix_diagonal=False)
results in a different error
mcmc_kernel = algorithm.build_kernel(integrator)
AttributeError: 'SamplingAlgorithm' object has no attribute 'build_kernel'
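The traceback shows that `window_adaptation` calls `algorithm.build_kernel(integrator)`, so it expects a factory object such as `blackjax.nuts` (as in the quickstart line) rather than the `SamplingAlgorithm` returned by `as_top_level_api`, which has no `build_kernel`. As a hedged stopgap that bypasses `window_adaptation` entirely, the step size alone can be tuned during warmup with a Robbins-Monro update on its logarithm. This is a swapped-in, simpler technique: unlike `window_adaptation` it does not estimate a mass matrix. `adapt_step_size`, `rwm_step` and `toy_logdensity` are hypothetical names, the demo kernel is plain random-walk Metropolis standing in for a Barker step, and `target_accept=0.3` is an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp


def adapt_step_size(rng_key, initial_position, step_fn, initial_step_size,
                    num_steps, target_accept):
    """Tune a step size with a Robbins-Monro update on its logarithm.

    step_fn(rng_key, position, step_size) must return (new_position, accept_prob).
    """
    dtype = initial_position.dtype

    def one_step(carry, inputs):
        position, log_step_size = carry
        key, t = inputs
        position, accept_prob = step_fn(key, position, jnp.exp(log_step_size))
        # Decreasing gain (t + 1)^-0.6: large early corrections, then it settles.
        gain = (t + 1.0) ** -0.6
        # Acceptance above target -> larger steps; below target -> smaller steps.
        log_step_size = log_step_size + gain * (accept_prob - target_accept)
        return (position, log_step_size), accept_prob

    keys = jax.random.split(rng_key, num_steps)
    steps = jnp.arange(num_steps, dtype=dtype)
    init = (initial_position, jnp.log(jnp.asarray(initial_step_size, dtype=dtype)))
    (position, log_step_size), accept_probs = jax.lax.scan(one_step, init, (keys, steps))
    return position, jnp.exp(log_step_size), accept_probs


def toy_logdensity(x):
    # Stand-in target: standard normal in len(x) dimensions.
    return -0.5 * jnp.sum(x ** 2)


def rwm_step(rng_key, position, step_size):
    # Stand-in kernel (random-walk Metropolis); a Barker step would go here.
    key_prop, key_u = jax.random.split(rng_key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_accept = toy_logdensity(proposal) - toy_logdensity(position)
    accept_prob = jnp.minimum(1.0, jnp.exp(log_accept))
    accepted = jnp.log(jax.random.uniform(key_u)) < log_accept
    return jnp.where(accepted, proposal, position), accept_prob


position, step_size, accept_probs = adapt_step_size(
    jax.random.PRNGKey(1), jnp.zeros(5), rwm_step,
    initial_step_size=1.0, num_steps=2000, target_accept=0.3,
)
```

After warmup, the tuned `step_size` would be held fixed for the sampling phase, for example as the `step_size` of the single-chain loop above.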
Can you please advise on this? Thanks.
Versions: blackjax 1.2.5, jax 0.5.0, running on CPU.