Skip to content

Commit

Permalink
Prithvi add loss dict
Browse files Browse the repository at this point in the history
Signed-off-by: Benedikt Blumenstiel <[email protected]>
  • Loading branch information
blumenstiel committed Feb 21, 2025
1 parent e282b09 commit f6e6b70
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions terratorch/models/backbones/prithvi_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,7 @@ def forward(
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
loss = self.forward_loss(pixel_values, pred, mask)
loss = {'loss': loss} # TerraTorch expects loss dict
return loss, pred, mask

def forward_features(
Expand Down

0 comments on commit f6e6b70

Please sign in to comment.