Skip to content

Commit

Permalink
0.9.3-beta release (neuralhydrology#3)
Browse files Browse the repository at this point in the history
MTS-LSTM: fix shared_mtslstm config argument
  • Loading branch information
gauchm authored Oct 7, 2020
1 parent 7f3b0b8 commit f1ff185
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion neuralhydrology/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.2-beta"
__version__ = "0.9.3-beta"
16 changes: 8 additions & 8 deletions neuralhydrology/modelzoo/mtslstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(self, cfg: Config):
# start to count the number of inputs
input_sizes = len(cfg.camels_attributes + cfg.hydroatlas_attributes + cfg.static_inputs)

# if not is_shared_mtslstm, the LSTM gets an additional frequency flag as input.
if not self._is_shared_mtslstm:
# if is_shared_mtslstm, the LSTM gets an additional frequency flag as input.
if self._is_shared_mtslstm:
input_sizes += len(self._frequencies)

if cfg.use_basin_id_encoding:
Expand All @@ -76,8 +76,8 @@ def __init__(self, cfg: Config):
if isinstance(cfg.dynamic_inputs, list):
input_sizes = {freq: input_sizes + len(cfg.dynamic_inputs) for freq in self._frequencies}
else:
if not self._is_shared_mtslstm:
raise ValueError(f'Different inputs not allowed if shared_mtslstm is False.')
if self._is_shared_mtslstm:
raise ValueError(f'Different inputs not allowed if shared_mtslstm is used.')
input_sizes = {freq: input_sizes + len(cfg.dynamic_inputs[freq]) for freq in self._frequencies}

if not isinstance(cfg.hidden_size, dict):
Expand All @@ -86,11 +86,11 @@ def __init__(self, cfg: Config):
else:
self._hidden_size = cfg.hidden_size

if (not self._is_shared_mtslstm
if (self._is_shared_mtslstm
or self._transfer_mtslstm_states["h"] == "identity"
or self._transfer_mtslstm_states["c"] == "identity") \
and any(size != self._hidden_size[self._frequencies[0]] for size in self._hidden_size.values()):
raise ValueError("All hidden sizes must be equal if shared_mtslstm=False or state transfer=identity.")
raise ValueError("All hidden sizes must be equal if shared_mtslstm is used or state transfer=identity.")

# create layer depending on selected frequencies
self._init_modules(input_sizes)
Expand All @@ -107,7 +107,7 @@ def _init_modules(self, input_sizes: Dict[str, int]):
for idx, freq in enumerate(self._frequencies):
freq_input_size = input_sizes[freq]

if not self._is_shared_mtslstm and idx > 0:
if self._is_shared_mtslstm and idx > 0:
self.lstms[freq] = self.lstms[self._frequencies[idx - 1]] # same LSTM for all frequencies.
self.heads[freq] = self.heads[self._frequencies[idx - 1]] # same head for all frequencies.
else:
Expand Down Expand Up @@ -162,7 +162,7 @@ def _prepare_inputs(self, data: Dict[str, torch.Tensor], freq: str) -> torch.Ten
else:
pass

if not self._is_shared_mtslstm:
if self._is_shared_mtslstm:
# add frequency one-hot encoding
idx = self._frequencies.index(freq)
one_hot_freq = torch.zeros(x_d.shape[0], x_d.shape[1], len(self._frequencies)).to(x_d)
Expand Down
2 changes: 1 addition & 1 deletion neuralhydrology/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def seq_length(self) -> Union[int, Dict[str, int]]:

@property
def shared_mtslstm(self) -> bool:
return self._cfg.get("shared_mtslstm", True)
return self._cfg.get("shared_mtslstm", False)

@property
def static_inputs(self) -> List[str]:
Expand Down

0 comments on commit f1ff185

Please sign in to comment.