Skip to content

Commit

Permalink
Update crit.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yalan-Song authored Dec 1, 2023
1 parent 8d87998 commit c39a815
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions hydroDL/model/crit.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def forward(self, output, target):
class NSELossBatch(torch.nn.Module):
# Same as Fredrick 2019, batch NSE loss
# stdarray: the standard deviation of the runoff for all basins
def __init__(self, stdarray, eps=0.1,device = torch.cuda.current_device()):
def __init__(self, stdarray, eps=0.1,device = 'cpu'):
super(NSELossBatch, self).__init__()
self.std = stdarray
self.eps = eps
Expand Down Expand Up @@ -222,7 +222,7 @@ def forward(self, output, target, igrid):
class NSESqrtLossBatch(torch.nn.Module):
# Same as Fredrick 2019, batch NSE loss, use RMSE and STD instead
# stdarray: the standard deviation of the runoff for all basins
def __init__(self, stdarray, eps=0.1,device = torch.cuda.current_device()):
def __init__(self, stdarray, eps=0.1,device = 'cpu'):
super(NSESqrtLossBatch, self).__init__()
self.std = stdarray
self.eps = eps
Expand Down

0 comments on commit c39a815

Please sign in to comment.