Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend/bands #106

Merged
merged 6 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion terratorch/models/backbones/prithvi_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from terratorch.datasets.utils import HLSBands
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
from terratorch.models.backbones.utils import _estimate_in_chans

PRETRAINED_BANDS = [
HLSBands.BLUE,
Expand Down Expand Up @@ -174,7 +175,8 @@ def _create_swin_mmseg_transformer(
# the current swin model is not multitemporal
if "num_frames" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "num_frames"}
kwargs["in_chans"] = len(model_bands)

kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)

def checkpoint_filter_wrapper_fn(state_dict, model):
return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)
Expand Down
3 changes: 2 additions & 1 deletion terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from terratorch.datasets import HLSBands
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.utils import _estimate_in_chans

PRETRAINED_BANDS = [
HLSBands.BLUE,
Expand Down Expand Up @@ -81,7 +82,7 @@ def _create_prithvi(
if "features_only" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}

kwargs["in_chans"] = len(model_bands)
kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)

def checkpoint_filter_wrapper_fn(state_dict, model):
return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)
Expand Down
35 changes: 35 additions & 0 deletions terratorch/models/backbones/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from terratorch.datasets import HLSBands

def _are_sublists_of_int(item) -> (bool, bool):

if all([isinstance(i, list) for i in item]):
if all([isinstance(i, int) for i in sum(item, [])]):
return True, True
else:
raise Exception(f"It's expected sublists be [int, int], but rceived {model_bands}")
elif len(item) == 2 and type(item[0]) == type(item[1]) == int:
return False, True
else:
return False, False

def _estimate_in_chans(
    model_bands: list[HLSBands] | list[str] | list[int] | tuple[int, int] | None = None,
) -> int:
    """Estimate the number of input channels implied by *model_bands*.

    Supports three ways of specifying bands:
      * a list of band names/enums (e.g. ``[HLSBands.RED, HLSBands.BLUE]``) --
        the channel count is simply the list length;
      * a pair of ints ``[start, end]`` -- an inclusive interval, so the count
        is ``end - start + 1`` (band numbering starts at 0, Python style);
      * a list of such int pairs -- the counts of all intervals are summed.

    Args:
        model_bands: the band specification, in any of the forms above.

    Returns:
        The estimated number of input channels.

    Raises:
        ValueError: if *model_bands* is None, or if interval sublists contain
            non-int entries.
    """
    # Fail early with a clear message instead of an opaque TypeError inside
    # the helper (the original default of None was never usable).
    if model_bands is None:
        msg = "model_bands must be provided to estimate the number of input channels"
        raise ValueError(msg)

    # Decide which of the three specification forms we were given.
    is_sublist, requires_special_eval = _are_sublists_of_int(model_bands)

    if requires_special_eval:
        if is_sublist:
            # Sum inclusive interval widths; use first/last entry of each
            # sublist (matches the original `i[-1] - i[0] + 1` behavior).
            return sum(i[-1] - i[0] + 1 for i in model_bands)
        # Single inclusive interval [start, end].
        return model_bands[-1] - model_bands[0] + 1
    # Plain list of band identifiers: one channel per band.
    return len(model_bands)


1 change: 0 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
x = batch["image"]
y = batch["mask"]
model_output: ModelOutput = self(x)

loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
y_hat_hard = to_segmentation_prediction(model_output)
Expand Down