From f1ff1856984233db0efcddf524189460aeaf3787 Mon Sep 17 00:00:00 2001 From: Martin Gauch <15731649+gauchm@users.noreply.github.com> Date: Wed, 7 Oct 2020 16:49:21 +0200 Subject: [PATCH] 0.9.3-beta release (#3) MTS-LSTM: fix shared_mtslstm config argument --- neuralhydrology/__about__.py | 2 +- neuralhydrology/modelzoo/mtslstm.py | 16 ++++++++-------- neuralhydrology/utils/config.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/neuralhydrology/__about__.py b/neuralhydrology/__about__.py index bbd0a5bf..363306ca 100644 --- a/neuralhydrology/__about__.py +++ b/neuralhydrology/__about__.py @@ -1 +1 @@ -__version__ = "0.9.2-beta" +__version__ = "0.9.3-beta" diff --git a/neuralhydrology/modelzoo/mtslstm.py b/neuralhydrology/modelzoo/mtslstm.py index d2fd9587..80b4829b 100644 --- a/neuralhydrology/modelzoo/mtslstm.py +++ b/neuralhydrology/modelzoo/mtslstm.py @@ -64,8 +64,8 @@ def __init__(self, cfg: Config): # start to count the number of inputs input_sizes = len(cfg.camels_attributes + cfg.hydroatlas_attributes + cfg.static_inputs) - # if not is_shared_mtslstm, the LSTM gets an additional frequency flag as input. - if not self._is_shared_mtslstm: + # if is_shared_mtslstm, the LSTM gets an additional frequency flag as input. + if self._is_shared_mtslstm: input_sizes += len(self._frequencies) if cfg.use_basin_id_encoding: @@ -76,8 +76,8 @@ def __init__(self, cfg: Config): if isinstance(cfg.dynamic_inputs, list): input_sizes = {freq: input_sizes + len(cfg.dynamic_inputs) for freq in self._frequencies} else: - if not self._is_shared_mtslstm: - raise ValueError(f'Different inputs not allowed if shared_mtslstm is False.') + if self._is_shared_mtslstm: + raise ValueError(f'Different inputs not allowed if shared_mtslstm is used.') input_sizes = {freq: input_sizes + len(cfg.dynamic_inputs[freq]) for freq in self._frequencies} if not isinstance(cfg.hidden_size, dict): @@ -86,11 +86,11 @@ def __init__(self, cfg: Config): else: self._hidden_size = cfg.hidden_size - if (not self._is_shared_mtslstm + if (self._is_shared_mtslstm or self._transfer_mtslstm_states["h"] == "identity" or self._transfer_mtslstm_states["c"] == "identity") \ and any(size != self._hidden_size[self._frequencies[0]] for size in self._hidden_size.values()): - raise ValueError("All hidden sizes must be equal if shared_mtslstm=False or state transfer=identity.") + raise ValueError("All hidden sizes must be equal if shared_mtslstm is used or state transfer=identity.") # create layer depending on selected frequencies self._init_modules(input_sizes) @@ -107,7 +107,7 @@ def _init_modules(self, input_sizes: Dict[str, int]): for idx, freq in enumerate(self._frequencies): freq_input_size = input_sizes[freq] - if not self._is_shared_mtslstm and idx > 0: + if self._is_shared_mtslstm and idx > 0: self.lstms[freq] = self.lstms[self._frequencies[idx - 1]] # same LSTM for all frequencies. self.heads[freq] = self.heads[self._frequencies[idx - 1]] # same head for all frequencies. else: @@ -162,7 +162,7 @@ def _prepare_inputs(self, data: Dict[str, torch.Tensor], freq: str) -> torch.Ten else: pass - if not self._is_shared_mtslstm: + if self._is_shared_mtslstm: # add frequency one-hot encoding idx = self._frequencies.index(freq) one_hot_freq = torch.zeros(x_d.shape[0], x_d.shape[1], len(self._frequencies)).to(x_d) diff --git a/neuralhydrology/utils/config.py b/neuralhydrology/utils/config.py index 98392af1..591655c7 100644 --- a/neuralhydrology/utils/config.py +++ b/neuralhydrology/utils/config.py @@ -468,7 +468,7 @@ def seq_length(self) -> Union[int, Dict[str, int]]: @property def shared_mtslstm(self) -> bool: - return self._cfg.get("shared_mtslstm", True) + return self._cfg.get("shared_mtslstm", False) @property def static_inputs(self) -> List[str]: