
Blackjax MCMC error #779

Open
danleonte opened this issue Feb 21, 2025 · 0 comments

danleonte commented Feb 21, 2025

Describe the issue as clearly as possible:

The example usage code from barker_mcmc

```python
barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)
```

raises an error:

```
barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
```

Possibly related to #723. I got it to work for a single chain:

```python
import jax
import jax.numpy as jnp
import blackjax

# Assumes log_density_fn, rng_key, step_size and num_samples are already defined.
# initial_states = jnp.repeat(jnp.array([15.,15.,0.,1.,0.])[jnp.newaxis,:],axis=0, repeats = num_chains)
initial_position = jnp.array([15., 15., 0., 1., 0.])  # shape (5,)
barker_state = blackjax.mcmc.barker.init(
    position=initial_position, logdensity_fn=log_density_fn)
barker_kernel = blackjax.barker.build_kernel()
# Single test step with an explicit step size
barker_kernel(rng_key=rng_key, state=barker_state,
              logdensity_fn=log_density_fn, step_size=0.1)


@jax.jit
def one_step(state, subkey):
    new_state, info = barker_kernel(subkey, state, log_density_fn, step_size)
    # Record the position after the transition
    return new_state, new_state.position


keys = jax.random.split(rng_key, num_samples)
final_state, positions = jax.lax.scan(one_step, barker_state, keys)
```

However, I can't figure out how to use window adaptation with it; it keeps raising an error. Judging by

```python
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
```

from the quickstart page, I need to pass the Barker kernel in some way, but the current one raises an error. Alternatively, this window adaptation attempt

```python
barker_kernel = blackjax.mcmc.barker.as_top_level_api(logdensity_fn=log_density_fn,
                                                      step_size=step_size,
                                                      inverse_mass_matrix=jnp.eye(5))
barker_kernel.step(rng, initial_state)

blackjax.window_adaptation(algorithm=barker_kernel, logdensity_fn=log_density_fn, is_mass_matrix_diagonal=False)
```

results in a different error:

```
   mcmc_kernel = algorithm.build_kernel(integrator)
AttributeError: 'SamplingAlgorithm' object has no attribute 'build_kernel'
```

Can you please advise on this? Thanks.

Steps/code to reproduce the bug:

```python
barker = blackjax.barker(logdensity_fn, step_size)
state = barker.init(position)
new_state, info = barker.step(rng_key, state)
```

Expected result:

The example should construct the sampler and take a step without raising an error.

Error message:

```
barker = blackjax.barker(log_density_fn, step_size)
TypeError: 'module' object is not callable
```

Blackjax/JAX/jaxlib/Python version information:

blackjax: 1.2.5
jax: 0.5.0 
on CPU

Context for the issue:

No response

@danleonte danleonte changed the title <Please write a descriptive title> Blackjax MCMC error Feb 21, 2025