From 3d16d1a69b8bd341619fc2c127114512ae3230b7 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Wed, 18 Dec 2024 20:23:28 +0000 Subject: [PATCH] Added tests for #1691. --- pints/tests/test_mcmc_controller.py | 54 ++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/pints/tests/test_mcmc_controller.py b/pints/tests/test_mcmc_controller.py index 98220805b..8d2a462e7 100755 --- a/pints/tests/test_mcmc_controller.py +++ b/pints/tests/test_mcmc_controller.py @@ -690,6 +690,26 @@ def test_log_pdf_storage_in_memory_single(self): likelihoods = [self.log_likelihood(x) for x in chain] self.assertTrue(np.all(evals[i] == likelihoods)) + # Test with a sensitivity-using method + mcmc = pints.MCMCController( + self.log_posterior, n_chains, xs, method=pints.MALAMCMC) + mcmc.set_max_iterations(n_iterations) + mcmc.set_log_to_screen(False) + mcmc.set_log_pdf_storage(True) + chains = mcmc.run() + evals = mcmc.log_pdfs() + self.assertEqual(len(evals.shape), 3) + self.assertEqual(evals.shape[0], n_chains) + self.assertEqual(evals.shape[1], n_iterations) + self.assertEqual(evals.shape[2], 3) + for i, chain in enumerate(chains): + posteriors = [self.log_posterior(x) for x in chain] + self.assertTrue(np.all(evals[i, :, 0] == posteriors)) + likelihoods = [self.log_likelihood(x) for x in chain] + self.assertTrue(np.all(evals[i, :, 1] == likelihoods)) + priors = [self.log_prior(x) for x in chain] + self.assertTrue(np.all(evals[i, :, 2] == priors)) + # Test disabling again mcmc = pints.MCMCController(self.log_posterior, n_chains, xs) mcmc.set_max_iterations(n_iterations) @@ -706,7 +726,9 @@ def test_log_pdf_storage_in_memory_multi(self): x0 = np.array(self.real_parameters) * 1.05 x1 = np.array(self.real_parameters) * 1.15 x2 = np.array(self.real_parameters) * 0.95 - xs = [x0, x1, x2] + x3 = np.array(self.real_parameters) * 0.95 + x4 = np.array(self.real_parameters) * 0.95 + xs = [x0, x1, x2, x3, x4] n_chains = len(xs) n_iterations = 100 meth = pints.DifferentialEvolutionMCMC @@ -737,6 +759,16 @@ def test_log_pdf_storage_in_memory_multi(self): priors = [self.log_prior(x) for x in chain] self.assertTrue(np.all(evals[i, :, 2] == priors)) + # Test with a sensitivity-using method + # We don't have any of these! + mcmc = pints.MCMCController( + self.log_posterior, n_chains, xs, + method=FakeMultiChainSamplerWithSensitivities) + mcmc.set_max_iterations(5) + mcmc.set_log_to_screen(False) + mcmc.set_log_pdf_storage(True) + chains = mcmc.run() + # Test with a loglikelihood mcmc = pints.MCMCController( self.log_likelihood, n_chains, xs, method=meth) @@ -1651,6 +1683,26 @@ def tell(self, fx): return None if x is None else (x, fx, np.array([True] * self._n)) +class FakeMultiChainSamplerWithSensitivities(pints.MultiChainMCMC): + """ + Fake sampler that pretends to be a multi-chain method that uses + sensitivities. At the moment (2024-12-18) we don't have these in PINTS, but + we need to check that code potentially handling this type of sampler exists + or raises not-implemented errors. + """ + def ask(self): + self._xs = self._x0 + 1e-3 * np.random.uniform( + size=(self._n_chains, self._n_parameters)) + return self._xs + + def current_log_pdfs(self): + return self._fxs + + def tell(self, fxs): + self._fxs = fxs + return self._xs, self._fxs, [True] * self._n_chains + + class TestMCMCControllerSingleChainStorage(unittest.TestCase): """ Tests storage of samples and evaluations to disk, running with a