Skip to content

Commit

Permalink
update test_barker so it works with metric.scale
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Sep 24, 2024
1 parent fd35a51 commit 8d580ec
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions tests/mcmc/test_barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
from absl.testing import absltest, parameterized

import blackjax
from blackjax.mcmc import metrics
from blackjax.mcmc.barker import _barker_pdf, _barker_sample_nd
from blackjax.util import run_inference_algorithm

Expand All @@ -23,13 +23,10 @@ def test_nd(self, seed):
jnp.array([1.0, -2.0, 10.0, 0.0]),
0.5,
)
# will have no effect on original test with no preconditioning matrix
inv_mass_matrix = jnp.eye(4)

metric = metrics.default_metric(jnp.eye(4))
keys = jax.random.split(key, n_samples)
samples = jax.vmap(
lambda k: _barker_sample_nd(k, m, a, scale, inv_mass_matrix)
)(keys)
samples = jax.vmap(lambda k: _barker_sample_nd(k, m, a, scale, metric))(keys)
# Check that the emprical mean and the mean computed as sum(x * p(x) dx) are close
_test_samples_vs_pdf(samples, lambda x: _barker_pdf(x, m, a, scale))

Expand Down Expand Up @@ -68,7 +65,6 @@ def test_preconditioning_matrix(self, seed):
We follow the discussion in Appendix G of the Barker 2020 paper.
"""
from blackjax.mcmc.barker import _get_mass_matrix_sqrt

key = jax.random.key(seed)
init_key, inference_key = jax.random.split(key, 2)
Expand All @@ -81,7 +77,7 @@ def test_preconditioning_matrix(self, seed):

# some non-diagonal positive-defininte matrix for pre-conditioning
inv_mass_matrix = jnp.array([[1, 0.1], [0.1, 1]])
C_t, C_t_inv = _get_mass_matrix_sqrt(inv_mass_matrix)
metric = metrics.default_metric(inv_mass_matrix)

# define barker kernel two ways
# non-scaled, use pre-conditioning
Expand All @@ -95,34 +91,42 @@ def logdensity(x, data):
state1 = barker1.init(true_x)

# scaled, trivial pre-conditioning
def scaled_logdensity(x_scaled, data, C_t):
return logdensity(C_t.dot(x_scaled), data)
def scaled_logdensity(x_scaled, data, metric):
x = metric.scale(x_scaled, x_scaled, False, False)
return logdensity(x, data)

logposterior_fn2 = functools.partial(scaled_logdensity, data=data, C_t=C_t)
logposterior_fn2 = functools.partial(
scaled_logdensity, data=data, metric=metric
)
barker2 = blackjax.barker_proposal(logposterior_fn2, 1e-1, jnp.eye(2))

true_x_trans = C_t_inv.dot(true_x)
true_x_trans = metric.scale(true_x, true_x, True, True)
state2 = barker2.init(true_x_trans)

n_steps = 10
_, states1 = run_inference_algorithm(
rng_key=inference_key,
initial_state=state1,
inference_algorithm=barker1,
transform=lambda state, info: state.position,
num_steps=1000,
num_steps=n_steps,
)

_, states2 = run_inference_algorithm(
rng_key=inference_key,
initial_state=state2,
inference_algorithm=barker2,
transform=lambda state, info: state.position,
num_steps=1000,
num_steps=n_steps,
)

# states should be the exact same with same random key after transforming
states2_trans = C_t.dot(states2.T).T
np.testing.assert_allclose(states1, states2_trans, atol=1e-2)
states2_trans = []
for ii in range(n_steps):
s = states2[ii]
states2_trans.append(metric.scale(s, s, False, False))
states2_trans = jnp.array(states2_trans)
assert jnp.allclose(states1, states2_trans)


if __name__ == "__main__":
Expand Down

0 comments on commit 8d580ec

Please sign in to comment.