Skip to content

Commit

Permalink
fix bug so now everything is tree_mapped in barker
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Oct 2, 2024
1 parent bd6ba3d commit 02e9cb9
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from blackjax.mcmc.metrics import Metric
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, Numeric, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "as_top_level_api"]

Expand Down Expand Up @@ -224,6 +225,15 @@ def step_fn(rng_key: PRNGKey, state):
return SamplingAlgorithm(init_fn, step_fn)


def _generate_bernoulli(
rng_key: PRNGKey, position: ArrayLikeTree, p: ArrayLikeTree
) -> ArrayTree:
pos, unravel_fn = ravel_pytree(position)
p_flat, _ = ravel_pytree(p)
sample = jax.random.bernoulli(rng_key, p=p_flat, shape=pos.shape)
return unravel_fn(sample)


def _barker_sample(key, mean, a, scale, metric):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.
Expand All @@ -244,18 +254,20 @@ def _barker_sample(key, mean, a, scale, metric):
"""

key1, key2 = jax.random.split(key)
flat_mean, _ = ravel_pytree(mean)

z = generate_gaussian_noise(key1, mean, sigma=scale)
c = metric.scale(mean, a, False, True)
z = scale * jax.random.normal(key1, shape=flat_mean.shape)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-a * (z - mean)))
log_p = jax.tree_util.tree_map(lambda x, y: -_log1pexp(-x * y), c, z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=flat_mean.shape)
p = jax.tree_util.tree_map(lambda x: jnp.exp(x), log_p)
b = _generate_bernoulli(key2, mean, p=p)

bz = jax.tree_util.tree_map(lambda x, y: x * y - (1 - x) * y, b, z)

return jax.tree_util.tree_map(
lambda a, b: a + b, mean, metric.scale(mean, b * z - (1 - b) * z, False, False)
lambda a, b: a + b, mean, metric.scale(mean, bz, False, False)
)


Expand Down

0 comments on commit 02e9cb9

Please sign in to comment.