diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2eb78236..7fee6027 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -283,7 +283,7 @@ def __init__(self, for i in range(depth): self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) self.feature_info.append( - {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"} + {"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"} ) self.blocks = nn.ModuleList(self.blocks) @@ -418,7 +418,7 @@ def forward_features( x = x + pos_embed[:, 1:, :] if self.temporal_encoding: - num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames + num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding if self.location_encoding: diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index b56994c2..6bee25fa 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -27,13 +27,25 @@ default_cfgs = generate_default_cfgs( { "prithvi_vit_100": { - "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M", - "hf_hub_filename": "Prithvi_100M.pt", + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", + "hf_hub_filename": "Prithvi_EO_V1_100M.pt", + }, + "prithvi_eo_v2_300": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M", + "hf_hub_filename": "Prithvi_EO_V2_300M.pt", + }, + "prithvi_eo_v2_300_tl": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL", + "hf_hub_filename": "Prithvi_EO_V2_300M_TL.pt", + }, + "prithvi_eo_v2_600": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M", + "hf_hub_filename": "Prithvi_EO_V2_600M.pt", + }, + "prithvi_eo_v2_600_tl": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL", + "hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt", }, - "prithvi_eo_v2_300": {}, - "prithvi_eo_v2_300_tl": {}, - "prithvi_eo_v2_600": {}, - "prithvi_eo_v2_600_tl": {}, "prithvi_vit_tiny": {} } )