diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py index 5a84c941..dba88c66 100644 --- a/sparse_autoencoder/optimizer/adam_with_reset.py +++ b/sparse_autoencoder/optimizer/adam_with_reset.py @@ -8,7 +8,8 @@ from torch import Tensor from torch.nn.parameter import Parameter from torch.optim import Adam -from torch.optim.optimizer import params_t +try: from torch.optim.optimizer import params_t +except ImportError: from torch.optim.optimizer import ParamsT as params_t from sparse_autoencoder.tensor_types import Axis