Skip to content

Commit

Permalink
Tests: Check full-rank covariance matrix
Browse files Browse the repository at this point in the history
Add assert statements that verify full-rank VI recovers the true,
full-rank covariance matrix.
  • Loading branch information
gil2rok committed Sep 11, 2024
1 parent 26da046 commit 4b4534f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/vi/test_fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def setUp(self):
super().setUp()
self.key = jax.random.key(42)

@chex.variants(with_jit=True, without_jit=False)
def test_recover_posterior(self):
ground_truth = [
# loc, scale
Expand All @@ -39,11 +38,15 @@ def logdensity_fn(x):
rng_key = self.key
for i in range(num_steps):
subkey = jax.random.fold_in(rng_key, i)
state, _ = self.variant(frvi.step)(subkey, state)
state, _ = jax.jit(frvi.step)(subkey, state)

loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"]
chol_factor = state.chol_params
scale_1, scale_2 = jnp.exp(chol_factor[0]), jnp.exp(chol_factor[1])
self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01)
self.assertAlmostEqual(scale_1, ground_truth[0][1], delta=0.01)
self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01)
self.assertAlmostEqual(scale_2, ground_truth[1][1], delta=0.01)


if __name__ == "__main__":
Expand Down

0 comments on commit 4b4534f

Please sign in to comment.