Skip to content

Commit

Permalink
Added set_hyperparameters method for PopulationMCMC. Closes #1327.
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelClerx committed Aug 11, 2021
1 parent 73268a5 commit 6c892e4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
26 changes: 26 additions & 0 deletions pints/_mcmc/_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,29 @@ def temperature_schedule(self):
distribution is ``p(theta|data) ^ (1 - T)``.
"""
return self._schedule

def n_hyper_parameters(self):
""" See :meth:`TunableMethod.n_hyper_parameters()`. """
return 1

def set_hyper_parameters(self, x):
"""
The hyper-parameter vector is ``[n_temperatures]``, where
``n_temperatures`` is an integer that will be passed to
:meth:`set_temperature_schedule`.
Note that, since the hyper-parameter vector should be 1d (without
nesting), setting an explicit temperature schedule is not supported via
the hyper-parameter interface.
See :meth:`TunableMethod.set_hyper_parameters()`.
"""
try:
n_temperatures = int(x[0])
except TypeError:
raise ValueError(
'First hyper-parameter must be (convertible to) an integer'
' (setting an explicit schedule is not supported through the'
' hyper-parameter interface).')

self.set_temperature_schedule(n_temperatures)
22 changes: 22 additions & 0 deletions pints/tests/test_mcmc_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,28 @@ def test_logging(self):
self.assertIn(' j ', text)
self.assertIn(' Ex. ', text)

def test_hyperparameters(self):

# Create mcmc
mcmc = pints.PopulationMCMC(self.real_parameters)

# Test setting with an int
mcmc.set_temperature_schedule(7)
x = mcmc.temperature_schedule()
self.assertEqual(len(x), 7)
mcmc.set_temperature_schedule(8)
self.assertEqual(len(mcmc.temperature_schedule()), 8)
mcmc.set_hyper_parameters([7])
y = mcmc.temperature_schedule()
self.assertEqual(len(y), 7)
self.assertTrue(np.all(x == y))

# Test setting with a list
mcmc.set_temperature_schedule(x)
self.assertRaisesRegex(
ValueError, 'First hyper-parameter',
mcmc.set_hyper_parameters, [x])


if __name__ == '__main__':
unittest.main()

0 comments on commit 6c892e4

Please sign in to comment.