From 6c892e4d2362f2433bca97a745467fad8a2a3ea9 Mon Sep 17 00:00:00 2001 From: Michael Clerx Date: Wed, 11 Aug 2021 03:14:46 +0100 Subject: [PATCH] Added set_hyperparameters method for PopulationMCMC. Closes #1327. --- pints/_mcmc/_population.py | 26 ++++++++++++++++++++++++++ pints/tests/test_mcmc_population.py | 22 ++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/pints/_mcmc/_population.py b/pints/_mcmc/_population.py index b0183a2b4e..c203315374 100644 --- a/pints/_mcmc/_population.py +++ b/pints/_mcmc/_population.py @@ -337,3 +337,29 @@ def temperature_schedule(self): distribution is ``p(theta|data) ^ (1 - T)``. """ return self._schedule + + def n_hyper_parameters(self): + """ See :meth:`TunableMethod.n_hyper_parameters()`. """ + return 1 + + def set_hyper_parameters(self, x): + """ + The hyper-parameter vector is ``[n_temperatures]``, where + ``n_temperatures`` is an integer that will be passed to + :meth:`set_temperature_schedule`. + + Note that, since the hyper-parameter vector should be 1d (without + nesting), setting an explicit temperature schedule is not supported via + the hyper-parameter interface. + + See :meth:`TunableMethod.set_hyper_parameters()`. + """ + try: + n_temperatures = int(x[0]) + except TypeError: + raise ValueError( + 'First hyper-parameter must be (convertible to) an integer' + ' (setting an explicit schedule is not supported through the' + ' hyper-parameter interface).') + + self.set_temperature_schedule(n_temperatures) diff --git a/pints/tests/test_mcmc_population.py b/pints/tests/test_mcmc_population.py index a1e9593bdd..6709f751cb 100755 --- a/pints/tests/test_mcmc_population.py +++ b/pints/tests/test_mcmc_population.py @@ -126,6 +126,28 @@ def test_logging(self): self.assertIn(' j ', text) self.assertIn(' Ex. ', text) + def test_hyperparameters(self): + + # Create mcmc + mcmc = pints.PopulationMCMC(self.real_parameters) + + # Test setting with an int + mcmc.set_temperature_schedule(7) + x = mcmc.temperature_schedule() + self.assertEqual(len(x), 7) + mcmc.set_temperature_schedule(8) + self.assertEqual(len(mcmc.temperature_schedule()), 8) + mcmc.set_hyper_parameters([7]) + y = mcmc.temperature_schedule() + self.assertEqual(len(y), 7) + self.assertTrue(np.all(x == y)) + + # Test setting with a list + mcmc.set_temperature_schedule(x) + self.assertRaisesRegex( + ValueError, 'First hyper-parameter', + mcmc.set_hyper_parameters, [x]) + if __name__ == '__main__': unittest.main()