From 16ee03e1ee4dc8d607987446e2b06bc0552404ff Mon Sep 17 00:00:00 2001 From: Moritz Hoffmann Date: Wed, 7 Sep 2022 15:44:49 +0200 Subject: [PATCH] var_cutoff can be disabled in covariance koopman models. (#255) --- deeptime/decomposition/_koopman.py | 9 +++++---- tests/decomposition/test_tica.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/deeptime/decomposition/_koopman.py b/deeptime/decomposition/_koopman.py index 16deb5f01..8ea5e4caf 100644 --- a/deeptime/decomposition/_koopman.py +++ b/deeptime/decomposition/_koopman.py @@ -344,15 +344,16 @@ def var_cutoff(self) -> Optional[float]: precedence over the :meth:`dim` parameter. :getter: Yields the current variance cutoff. - :setter: Sets a new variance cutoff + :setter: Sets a new variance cutoff or disables variance cutoff by setting the value to `None`. :type: float or None """ return self._var_cutoff @var_cutoff.setter - def var_cutoff(self, value): - assert 0 < value <= 1., "Invalid dimension parameter, if it is given in terms of a variance cutoff, " \ - "it can only be in the interval (0, 1]." + def var_cutoff(self, value: Optional[float]): + assert value is None or 0 < value <= 1., \ + "Invalid dimension parameter, if it is given in terms of a variance cutoff, " \ + "it can only be in the interval (0, 1]." self._var_cutoff = value self._update_output_dimension() diff --git a/tests/decomposition/test_tica.py b/tests/decomposition/test_tica.py index 8f9ad2fd6..ad400e5b1 100644 --- a/tests/decomposition/test_tica.py +++ b/tests/decomposition/test_tica.py @@ -8,6 +8,7 @@ import numpy as np import pytest +from numpy.testing import assert_, assert_equal from deeptime.covariance import Covariance from deeptime.data import ellipsoids @@ -16,6 +17,22 @@ from deeptime.numeric import ZeroRankError +def test_update_projection_dimension(): + # tests for https://github.com/deeptime-ml/deeptime/issues/254 + data = np.random.normal(size=(1000, 50)) + model = TICA(lagtime=1, var_cutoff=.1).fit_fetch(data) + assert_equal(model.var_cutoff, .1) + assert_(model.transform(data).shape[1] <= 10) + model.var_cutoff = None + assert_equal(model.var_cutoff, None) + model.dim = 5 + assert_equal(model.dim, 5) + assert_(model.transform(data).shape[1] <= 5) + model.dim = 1 + assert_equal(model.dim, 1) + assert_equal(model.transform(data).shape[1], 1) + + def test_fit_reset(): lag = 100 np.random.seed(0)