Merge pull request #1104 from pints-team/930-gradient-descent

Simple gradient descent optimiser

Showing 10 changed files with 879 additions and 6 deletions.
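Before the file-by-file diff, here is a minimal usage sketch (not part of the PR itself) showing how the new optimiser slots into the standard PINTS controller workflow; it uses only names that appear in the diff below:

import pints
import pints.toy

# Toy problem used in the tests below: differentiable, known minimum
error = pints.toy.ParabolicError()

# Standard controller workflow, selecting the new method
opt = pints.OptimisationController(
    error, [0.1, 0.1], method=pints.GradientDescent)
opt.set_log_to_screen(False)

# Optionally change the fixed learning rate (defaults to 0.01)
opt.optimiser().set_learning_rate(0.05)

xbest, fbest = opt.run()

The learning rate is the method's only hyper-parameter; see set_hyper_parameters() in the implementation below.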
@@ -0,0 +1,8 @@
**************************************
Gradient descent (fixed learning rate)
**************************************

.. currentmodule:: pints

.. autoclass:: GradientDescent
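For orientation, the update implemented by tell() in the next file is plain fixed-rate gradient descent; in LaTeX:

    x_{t+1} = x_t - \eta \, \nabla f(x_t)

where \eta is the learning rate (0.01 by default) and f is the error measure being minimised.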
@@ -0,0 +1,126 @@
#
# Fixed learning-rate gradient descent.
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
from __future__ import absolute_import, division
from __future__ import print_function, unicode_literals

import pints


class GradientDescent(pints.Optimiser):
    """
    Gradient-descent method with a fixed learning rate.
    """

    def __init__(self, x0, sigma0=0.1, boundaries=None):
        super(GradientDescent, self).__init__(x0, sigma0, boundaries)

        # Set optimiser state
        self._running = False
        self._ready_for_tell = False

        # Best solution found
        self._xbest = self._x0
        self._fbest = float('inf')

        # Learning rate
        self._eta = 0.01

        # Current point, score, and gradient
        self._current = self._x0
        self._current_f = None
        self._current_df = None

        # Proposed next point (read-only, so can be passed to user)
        self._proposed = self._x0
        self._proposed.setflags(write=False)

    def ask(self):
        """ See :meth:`Optimiser.ask()`. """

        # Running, and ready for tell now
        self._ready_for_tell = True
        self._running = True

        # Return proposed points (just the one)
        return [self._proposed]

    def fbest(self):
        """ See :meth:`Optimiser.fbest()`. """
        return self._fbest

    def learning_rate(self):
        """ Returns this optimiser's learning rate. """
        return self._eta

    def name(self):
        """ See :meth:`Optimiser.name()`. """
        return 'Gradient descent'

    def needs_sensitivities(self):
        """ See :meth:`Optimiser.needs_sensitivities()`. """
        return True

    def n_hyper_parameters(self):
        """ See :meth:`pints.TunableMethod.n_hyper_parameters()`. """
        return 1

    def running(self):
        """ See :meth:`Optimiser.running()`. """
        return self._running

    def set_hyper_parameters(self, x):
        """
        See :meth:`pints.TunableMethod.set_hyper_parameters()`.

        The hyper-parameter vector is ``[learning_rate]``.
        """
        self.set_learning_rate(x[0])

    def set_learning_rate(self, eta):
        """
        Sets the learning rate for this optimiser.

        Parameters
        ----------
        eta : float
            The learning rate, as a float greater than zero.
        """
        eta = float(eta)
        if eta <= 0:
            raise ValueError('Learning rate must be greater than zero.')
        self._eta = eta

    def tell(self, reply):
        """ See :meth:`Optimiser.tell()`. """

        # Check ask-tell pattern
        if not self._ready_for_tell:
            raise Exception('ask() not called before tell()')
        self._ready_for_tell = False

        # Unpack reply
        fx, dfx = reply[0]

        # Move to proposed point
        self._current = self._proposed
        self._current_f = fx
        self._current_df = dfx

        # Propose next point
        self._proposed = self._current - self._eta * dfx
        self._proposed.setflags(write=False)

        # Update xbest and fbest
        if self._fbest > fx:
            self._fbest = fx
            self._xbest = self._current

    def xbest(self):
        """ See :meth:`Optimiser.xbest()`. """
        return self._xbest
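Because the class implements PINTS's ask/tell interface, it can also be driven by hand. A minimal sketch of one such loop, assuming pints.toy.ParabolicError supports evaluateS1() (PINTS's convention for returning an (f, df) pair, as implied by needs_sensitivities() above):

import pints
import pints.toy

error = pints.toy.ParabolicError()
opt = pints.GradientDescent([0.1, 0.1])

for i in range(100):
    xs = opt.ask()                               # one read-only proposal
    replies = [error.evaluateS1(x) for x in xs]  # one (f, df) pair per point
    opt.tell(replies)

print(opt.xbest(), opt.fbest())

Note that tell() expects exactly one (f, df) reply per point returned by ask(), since this optimiser proposes a single point per iteration.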
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
#
# Tests the basic methods of the Gradient Descent optimiser.
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#
import unittest
import numpy as np

import pints
import pints.toy

from shared import CircularBoundaries


debug = False
method = pints.GradientDescent

# Consistent unit testing in Python 2 and 3
try:
    unittest.TestCase.assertRaisesRegex
except AttributeError:
    unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp


class TestGradientDescent(unittest.TestCase):
    """
    Tests the basic methods of the gradient descent optimiser.
    """
    def setUp(self):
        """ Called before every test """
        np.random.seed(1)

    def problem(self):
        """ Returns a test problem, starting point, sigma, and boundaries. """
        r = pints.toy.ParabolicError()
        x = [0.1, 0.1]
        s = 0.1
        b = pints.RectangularBoundaries([-1, -1], [1, 1])
        return r, x, s, b

    def test_unbounded(self):
        # Runs an optimisation without boundaries.
        r, x, s, b = self.problem()
        opt = pints.OptimisationController(r, x, method=method)
        opt.set_threshold(1e-3)
        opt.set_log_to_screen(debug)
        found_parameters, found_solution = opt.run()
        self.assertTrue(found_solution < 1e-3)

    def test_bounded(self):
        # Runs an optimisation with boundaries.
        r, x, s, b = self.problem()

        # Rectangular boundaries
        b = pints.RectangularBoundaries([-1, -1], [1, 1])
        opt = pints.OptimisationController(r, x, boundaries=b, method=method)
        opt.set_log_to_screen(debug)
        found_parameters, found_solution = opt.run()
        self.assertTrue(found_solution < 1e-3)

        # Circular boundaries
        # Start near edge, to increase chance of out-of-bounds occurring.
        b = CircularBoundaries([0, 0], 1)
        x = [0.99, 0]
        opt = pints.OptimisationController(r, x, boundaries=b, method=method)
        opt.set_log_to_screen(debug)
        found_parameters, found_solution = opt.run()
        self.assertTrue(found_solution < 1e-3)

    def test_bounded_and_sigma(self):
        # Runs an optimisation with boundaries and a sigma.
        r, x, s, b = self.problem()
        opt = pints.OptimisationController(r, x, s, b, method)
        opt.set_threshold(1e-3)
        opt.set_log_to_screen(debug)
        found_parameters, found_solution = opt.run()
        self.assertTrue(found_solution < 1e-3)

    def test_ask_tell(self):
        # Tests ask-and-tell related error handling.
        r, x, s, b = self.problem()
        opt = method(x)

        # Stop called when not running
        self.assertFalse(opt.running())
        self.assertFalse(opt.stop())

        # Best position and score called before run
        self.assertEqual(list(opt.xbest()), list(x))
        self.assertEqual(opt.fbest(), float('inf'))

        # Tell before ask
        self.assertRaisesRegex(
            Exception, r'ask\(\) not called before tell\(\)', opt.tell, 5)

        # Ask
        opt.ask()

        # Now we should be running
        self.assertTrue(opt.running())

    def test_hyper_parameter_interface(self):
        # Tests the hyper parameter interface for this optimiser.
        r, x, s, b = self.problem()
        opt = pints.OptimisationController(r, x, method=method)
        m = opt.optimiser()
        self.assertEqual(m.n_hyper_parameters(), 1)
        eta = m.learning_rate() * 2
        m.set_hyper_parameters([eta])
        self.assertEqual(m.learning_rate(), eta)
        self.assertRaisesRegex(
            ValueError, 'greater than zero', m.set_hyper_parameters, [0])

    def test_name(self):
        # Test the name() method.
        opt = method(np.array([0, 1.01]))
        self.assertIn('radient descent', opt.name())


if __name__ == '__main__':
    print('Add -v for more debug output')
    import sys
    if '-v' in sys.argv:
        debug = True
        import logging
        logging.basicConfig(level=logging.DEBUG)
    unittest.main()