Skip to content

Commit

Permalink
Merge pull request #273 from blumenstiel/main
Browse files Browse the repository at this point in the history
Update prithvi model
  • Loading branch information
romeokienzler authored Dec 3, 2024
2 parents 6ca6238 + fd5ea25 commit 5247667
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
4 changes: 2 additions & 2 deletions terratorch/models/backbones/prithvi_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def __init__(self,
for i in range(depth):
self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
self.feature_info.append(
{"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
{"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"}
)
self.blocks = nn.ModuleList(self.blocks)

Expand Down Expand Up @@ -418,7 +418,7 @@ def forward_features(
x = x + pos_embed[:, 1:, :]

if self.temporal_encoding:
num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
num_tokens_per_frame = x.shape[1] // self.num_frames
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
x = x + temporal_encoding
if self.location_encoding:
Expand Down
24 changes: 18 additions & 6 deletions terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@
default_cfgs = generate_default_cfgs(
{
"prithvi_vit_100": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
"hf_hub_filename": "Prithvi_100M.pt",
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
"hf_hub_filename": "Prithvi_EO_V1_100M.pt",
},
"prithvi_eo_v2_300": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
"hf_hub_filename": "Prithvi_EO_V2_300M.pt",
},
"prithvi_eo_v2_300_tl": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
"hf_hub_filename": "Prithvi_EO_V2_300M_TL.pt",
},
"prithvi_eo_v2_600": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M",
"hf_hub_filename": "Prithvi_EO_V2_600M.pt",
},
"prithvi_eo_v2_600_tl": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL",
"hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt",
},
"prithvi_eo_v2_300": {},
"prithvi_eo_v2_300_tl": {},
"prithvi_eo_v2_600": {},
"prithvi_eo_v2_600_tl": {},
"prithvi_vit_tiny": {}
}
)
Expand Down

0 comments on commit 5247667

Please sign in to comment.