Skip to content

Commit

Permalink
Added tests for #1691.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelClerx committed Dec 18, 2024
1 parent 0f0ae2e commit 3d16d1a
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion pints/tests/test_mcmc_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,26 @@ def test_log_pdf_storage_in_memory_single(self):
likelihoods = [self.log_likelihood(x) for x in chain]
self.assertTrue(np.all(evals[i] == likelihoods))

# Test with a sensitivity-using method
mcmc = pints.MCMCController(
self.log_posterior, n_chains, xs, method=pints.MALAMCMC)
mcmc.set_max_iterations(n_iterations)
mcmc.set_log_to_screen(False)
mcmc.set_log_pdf_storage(True)
chains = mcmc.run()
evals = mcmc.log_pdfs()
self.assertEqual(len(evals.shape), 3)
self.assertEqual(evals.shape[0], n_chains)
self.assertEqual(evals.shape[1], n_iterations)
self.assertEqual(evals.shape[2], 3)
for i, chain in enumerate(chains):
posteriors = [self.log_posterior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 0] == posteriors))
likelihoods = [self.log_likelihood(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 1] == likelihoods))
priors = [self.log_prior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 2] == priors))

# Test disabling again
mcmc = pints.MCMCController(self.log_posterior, n_chains, xs)
mcmc.set_max_iterations(n_iterations)
Expand All @@ -706,7 +726,9 @@ def test_log_pdf_storage_in_memory_multi(self):
x0 = np.array(self.real_parameters) * 1.05
x1 = np.array(self.real_parameters) * 1.15
x2 = np.array(self.real_parameters) * 0.95
xs = [x0, x1, x2]
x3 = np.array(self.real_parameters) * 0.95
x4 = np.array(self.real_parameters) * 0.95
xs = [x0, x1, x2, x3, x4]
n_chains = len(xs)
n_iterations = 100
meth = pints.DifferentialEvolutionMCMC
Expand Down Expand Up @@ -737,6 +759,16 @@ def test_log_pdf_storage_in_memory_multi(self):
priors = [self.log_prior(x) for x in chain]
self.assertTrue(np.all(evals[i, :, 2] == priors))

# Test with a sensitivity-using method
# We don't have any of these!
mcmc = pints.MCMCController(
self.log_posterior, n_chains, xs,
method=FakeMultiChainSamplerWithSensitivities)
mcmc.set_max_iterations(5)
mcmc.set_log_to_screen(False)
mcmc.set_log_pdf_storage(True)
chains = mcmc.run()

# Test with a loglikelihood
mcmc = pints.MCMCController(
self.log_likelihood, n_chains, xs, method=meth)
Expand Down Expand Up @@ -1651,6 +1683,26 @@ def tell(self, fx):
return None if x is None else (x, fx, np.array([True] * self._n))


class FakeMultiChainSamplerWithSensitivities(pints.MultiChainMCMC):
"""
Fake sampler that pretends to be a multi-chain method that uses
sensitivities. At the moment (2024-12-18) we don't have these in PINTS, but
we need to check that code potentially handling this type of sampler exists
or raises not-implemented errors.
"""
def ask(self):
self._xs = self._x0 + 1e-3 * np.random.uniform(
size=(self._n_chains, self._n_parameters))
return self._xs

def current_log_pdfs(self):
return self._fxs

def tell(self, fxs):
self._fxs = fxs
return self._xs, self._fxs, [True] * self._n_chains


class TestMCMCControllerSingleChainStorage(unittest.TestCase):
"""
Tests storage of samples and evaluations to disk, running with a
Expand Down

0 comments on commit 3d16d1a

Please sign in to comment.