From a7e18313c038d8bd8edf0925a6acf3394fe0bb39 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Tue, 24 Sep 2024 14:58:44 -0500 Subject: [PATCH] fix tests add trans to scale --- tests/mcmc/test_metrics.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 0791f3cb1..098649a9a 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -131,8 +131,8 @@ def test_gaussian_euclidean_dim_1(self): assert momentum_val == expected_momentum_val assert kinetic_energy_val == expected_kinetic_energy_val - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -164,8 +164,8 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val @@ -226,8 +226,8 @@ def test_gaussian_riemannian_dim_1(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_scaled_momentum = momentum_val / jnp.sqrt(inverse_mass_matrix) expected_inv_scaled_momentum = momentum_val * jnp.sqrt(inverse_mass_matrix) @@ -265,8 +265,8 @@ def test_gaussian_riemannian_dim_2(self): np.testing.assert_allclose(expected_momentum_val, momentum_val) np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) - inv_scaled_momentum = scale(arbitrary_position, momentum_val, True) - scaled_momentum = scale(arbitrary_position, momentum_val, False) + inv_scaled_momentum = scale(arbitrary_position, momentum_val, True, False) + scaled_momentum = scale(arbitrary_position, momentum_val, False, False) expected_inv_scaled_momentum = jnp.linalg.inv(L_inv).T @ momentum_val expected_scaled_momentum = L_inv @ momentum_val