Skip to content

Commit

Permalink
More hardcoded values removed
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Aug 16, 2024
1 parent 7394f7f commit f2d1d58
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions terratorch/models/backbones/vit_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def forward_encoder(
x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
t, h, w = x.shape[-3:]
x = self.patch_embed(x)
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // 16, w // 16), cls_token=True)).to(
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(self.embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)).to(
x
)
# add pos embed w/o cls token
Expand Down Expand Up @@ -378,7 +378,7 @@ def forward_decoder(self, x: torch.Tensor, ids_restore: torch.Tensor, dim_info:
x = self.decoder_embed(x)
t, h, w = dim_info
decoder_pos_embed = torch.from_numpy(
get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // 16, w // 16), cls_token=True)
get_3d_sincos_pos_embed(self.decoder_embed_dim, (t, h // self.patch_size, w // self.patch_size), cls_token=True)
).to(x)

# append mask tokens to sequence
Expand Down

0 comments on commit f2d1d58

Please sign in to comment.