Skip to content

Commit

Permalink
var_cutoff can be disabled in covariance koopman models. (#255)
Browse files Browse the repository at this point in the history
  • Loading branch information
clonker authored Sep 7, 2022
1 parent 153d9f6 commit 16ee03e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
9 changes: 5 additions & 4 deletions deeptime/decomposition/_koopman.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,16 @@ def var_cutoff(self) -> Optional[float]:
precedence over the :meth:`dim` parameter.
:getter: Yields the current variance cutoff.
:setter: Sets a new variance cutoff
:setter: Sets a new variance cutoff or disables variance cutoff by setting the value to `None`.
:type: float or None
"""
return self._var_cutoff

@var_cutoff.setter
def var_cutoff(self, value):
assert 0 < value <= 1., "Invalid dimension parameter, if it is given in terms of a variance cutoff, " \
"it can only be in the interval (0, 1]."
def var_cutoff(self, value: Optional[float]):
assert value is None or 0 < value <= 1., \
"Invalid dimension parameter, if it is given in terms of a variance cutoff, " \
"it can only be in the interval (0, 1]."
self._var_cutoff = value
self._update_output_dimension()

Expand Down
17 changes: 17 additions & 0 deletions tests/decomposition/test_tica.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pytest
from numpy.testing import assert_, assert_equal

from deeptime.covariance import Covariance
from deeptime.data import ellipsoids
Expand All @@ -16,6 +17,22 @@
from deeptime.numeric import ZeroRankError


def test_update_projection_dimension():
# tests for https://github.com/deeptime-ml/deeptime/issues/254
data = np.random.normal(size=(1000, 50))
model = TICA(lagtime=1, var_cutoff=.1).fit_fetch(data)
assert_equal(model.var_cutoff, .1)
assert_(model.transform(data).shape[1] <= 10)
model.var_cutoff = None
assert_equal(model.var_cutoff, None)
model.dim = 5
assert_equal(model.dim, 5)
assert_(model.transform(data).shape[1] <= 5)
model.dim = 1
assert_equal(model.dim, 1)
assert_equal(model.transform(data).shape[1], 1)


def test_fit_reset():
lag = 100
np.random.seed(0)
Expand Down

0 comments on commit 16ee03e

Please sign in to comment.