Skip to content

Commit

Permalink
Update CudnnLstmModel.py
Browse files Browse the repository at this point in the history
Updated for new torch versions
  • Loading branch information
Yalan-Song authored Oct 26, 2022
1 parent 57f9a80 commit 17f0be2
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions hydroDL/model/rnn/CudnnLstmModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,10 @@ def __init__(self, *, nx, ny, hiddenSize, dr=0.5, warmUpDay=None):
self.ct = 0
self.nLayer = 1
self.linearIn = torch.nn.Linear(nx, hiddenSize)
if torch.__version__ > "1.9":
# 2021-10-24. SCP: incorporate newer version of torch LSTM to avoid "weights not contiguous on memory" issue
self.lstm = torch.nn.LSTM(hiddenSize, hiddenSize, 2, dropout=dr)
else:
self.lstm = rnn.CudnnLstm(
inputSize=hiddenSize, hiddenSize=hiddenSize, dr=dr
)

self.lstm = rnn.CudnnLstm(
inputSize=hiddenSize, hiddenSize=hiddenSize, dr=dr
)
self.linearOut = torch.nn.Linear(hiddenSize, ny)
self.gpu = 1
self.name = "CudnnLstmModel"
Expand All @@ -46,12 +43,10 @@ def forward(self, x, doDropMC=False, dropoutFalse=False):
x, warmUpDay = self.extend_day(x, warmUpDay=self.warmUpDay)

x0 = F.relu(self.linearIn(x))
if torch.__version__ > "1.9":
outLSTM, (hn, cn) = self.lstm(x0)
else:
outLSTM, (hn, cn) = self.lstm(
x0, doDropMC=doDropMC, dropoutFalse=dropoutFalse
)

outLSTM, (hn, cn) = self.lstm(
x0, doDropMC=doDropMC, dropoutFalse=dropoutFalse
)
# outLSTMdr = self.drtest(outLSTM)
out = self.linearOut(outLSTM)

Expand All @@ -69,4 +64,4 @@ def extend_day(self, x, warm_up_day):

def reduce_day(self, x, warm_up_day):
x = x[warm_up_day:,:,:]
return x
return x

0 comments on commit 17f0be2

Please sign in to comment.