
When tubulet_size > 1, the number of temporal patches must be evaluated properly

Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Aug 16, 2024
1 parent f2d1d58 commit 4230854
Showing 1 changed file with 4 additions and 3 deletions.
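
Why the change is needed, as a minimal standalone sketch (hypothetical numbers — patch_size=16, tubelet_size=3, a 3-frame 224x224 input — plain Python, not terratorch's actual classes): the patch embedding produces one token per tubelet, so a positional grid built from the raw frame count t is tubelet_size times too long whenever tubelet_size > 1.

# Hypothetical shapes, for illustration only.
patch_size, tubelet_size = 16, 3
t, h, w = 3, 224, 224  # frames, height, width of the raw input

# Tokens produced by a 3D patch embedding: one per tubelet.
num_tokens = (t // tubelet_size) * (h // patch_size) * (w // patch_size)

# Old grid used the raw frame count; the fixed grid divides by tubelet_size.
old_grid = t * (h // patch_size) * (w // patch_size)
new_grid = (t // tubelet_size) * (h // patch_size) * (w // patch_size)

assert new_grid == num_tokens       # positions line up with patch tokens
assert old_grid == 3 * num_tokens   # 3x too many positions before the fix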
7 changes: 4 additions & 3 deletions terratorch/models/backbones/vit_encoder_decoder.py
@@ -186,6 +186,7 @@ def __init__(
         # MAE encoder specifics
         self.patch_embed = PatchEmbed(pretrain_img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
         self.patch_size = patch_size
+        self.tubulet_size = tubulet_size
         self.feature_info = []
         self.in_chans = in_chans
         self.num_frames = num_frames
@@ -347,7 +348,7 @@ def forward_encoder(
         x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
         t, h, w = x.shape[-3:]
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
@@ -378,7 +379,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
         x = self.decoder_embed(x)
         t, h, w = dim_info
         decoder_pos_embed = torch.from_numpy(
-            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)
+            get_3d_sincos_pos_embed(self.decoder_embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)
         ).to(x)

         # append mask tokens to sequence
@@ -436,7 +437,7 @@ def forward_features(self, x) -> list[torch.Tensor]:
         t, h, w = x.shape[-3:]
         # embed patches
         x = self.patch_embed(x)
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
+        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t // self.tubulet_size, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
             x
         )
         # add pos embed w/o cls token
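
A usage check of the corrected call, as a sketch (the import path is an assumption — in the diff get_3d_sincos_pos_embed is simply in scope in the module — and the shapes are hypothetical): the returned table should have one row per tubelet token plus one for the cls token.

import torch
# Assumed import location for the helper used unqualified in the diff above.
from terratorch.models.backbones.vit_encoder_decoder import get_3d_sincos_pos_embed

embed_dim, patch_size, tubelet_size = 768, 16, 3
t, h, w = 3, 224, 224

# Temporal grid divided by tubelet_size, as in the fixed code.
grid = (t // tubelet_size, h // patch_size, w // patch_size)
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(embed_dim, grid, cls_token=True))

expected_rows = grid[0] * grid[1] * grid[2] + 1  # tokens + cls token
assert pos_embed.shape == (expected_rows, embed_dim)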
