From ee17c9696f31fd91b7bf1dcc81139db23188ba4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 8 Aug 2024 12:38:36 -0300 Subject: [PATCH 1/6] Logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/tasks/segmentation_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 5f1bcb10..6e54dc4e 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -254,7 +254,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["mask"] model_output: ModelOutput = self(x) - + print(x.shape, model_output.output.shape, y.shape) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) From 233a5c28a1a361991844ef5283045b0cb0c582e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 9 Aug 2024 10:49:00 -0300 Subject: [PATCH 2/6] The number of channels can be estimated in different ways according to the kind of bands definition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/prithvi_swin.py | 4 +++- terratorch/models/backbones/prithvi_vit.py | 3 ++- terratorch/tasks/segmentation_tasks.py | 1 - 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/terratorch/models/backbones/prithvi_swin.py b/terratorch/models/backbones/prithvi_swin.py index 2587a545..b09137a8 100644 --- a/terratorch/models/backbones/prithvi_swin.py +++ b/terratorch/models/backbones/prithvi_swin.py @@ -18,6 +18,7 @@ from terratorch.datasets.utils import HLSBands from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer +from terratorch.models.backbones.utils import _estimate_in_chans PRETRAINED_BANDS = [ HLSBands.BLUE, @@ -174,7 +175,8 @@ def _create_swin_mmseg_transformer( # the current swin model is not multitemporal if "num_frames" in kwargs: kwargs = {k: v for k, v in kwargs.items() if k != "num_frames"} - kwargs["in_chans"] = len(model_bands) + + kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands) def checkpoint_filter_wrapper_fn(state_dict, model): return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 03691307..a722d674 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -14,6 +14,7 @@ from terratorch.datasets import HLSBands from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder +from terratorch.models.backbones.utils import _estimate_in_chans PRETRAINED_BANDS = [ HLSBands.BLUE, @@ -81,7 +82,7 @@ def _create_prithvi( if "features_only" in kwargs: kwargs = {k: v for k, v in kwargs.items() if k != "features_only"} - kwargs["in_chans"] = len(model_bands) + kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands) def checkpoint_filter_wrapper_fn(state_dict, model): return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 6e54dc4e..5f123351 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -254,7 +254,6 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["mask"] model_output: ModelOutput = self(x) - print(x.shape, model_output.output.shape, y.shape) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) From 6bd2bde9c87f1af63587d2baa58f0fbb7051608d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 9 Aug 2024 11:27:30 -0300 Subject: [PATCH 3/6] terratorch.models.backbones.utils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/utils.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 terratorch/models/backbones/utils.py diff --git a/terratorch/models/backbones/utils.py b/terratorch/models/backbones/utils.py new file mode 100644 index 00000000..8945fdfd --- /dev/null +++ b/terratorch/models/backbones/utils.py @@ -0,0 +1,31 @@ +from terratorch.datasets import HLSBands + +def _are_sublists_of_int(item) -> bool: + + if all([isinstance(i, list) for i in item]): + if all([isinstance(i, int) for i in sum(item, [])]): + return True + else: + return False + else: + return False + +def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] = None) -> int: + + # Conditional to deal with the different possible choices for the bands + # Bands as lists of strings or enum + if all([isinstance(b, str) for b in model_bands]): + in_chans = len(model_bands) + # Bands as intervals limited by integers + elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)): + + if _are_sublists_of_int(model_bands): + in_chans = sum([i[-1] - i[0] for i in model_bands]) + else: + in_chans = model_bands[-1] - model_bands[0] + else: + raise Exception(f"Expected bands to be list(str) or [int, int] but received {model_bands}") + + return in_chans + + From 52f0c12baf867c2bd7900211899ea05f7e4dc604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 9 Aug 2024 12:33:57 -0300 Subject: [PATCH 4/6] HLSBands also need be considered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/models/backbones/utils.py b/terratorch/models/backbones/utils.py index 8945fdfd..06e4f462 100644 --- a/terratorch/models/backbones/utils.py +++ b/terratorch/models/backbones/utils.py @@ -14,7 +14,7 @@ def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] # Conditional to deal with the different possible choices for the bands # Bands as lists of strings or enum - if all([isinstance(b, str) for b in model_bands]): + if all([isinstance(b, str) or isinstance(b, HLSBands) for b in model_bands]): in_chans = len(model_bands) # Bands as intervals limited by integers elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)): From 3d9b5f57a0e0038ac31c7941dab9f8a7573f2256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 9 Aug 2024 13:17:11 -0300 Subject: [PATCH 5/6] Better conditional for all the kinds of bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/utils.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/terratorch/models/backbones/utils.py b/terratorch/models/backbones/utils.py index 06e4f462..e415e9b4 100644 --- a/terratorch/models/backbones/utils.py +++ b/terratorch/models/backbones/utils.py @@ -1,31 +1,33 @@ from terratorch.datasets import HLSBands -def _are_sublists_of_int(item) -> bool: +def _are_sublists_of_int(item) -> (bool, bool): if all([isinstance(i, list) for i in item]): if all([isinstance(i, int) for i in sum(item, [])]): - return True + return True, True else: - return False + raise Exception(f"It's expected sublists be [int, int], but rceived {model_bands}") + elif len(item) == 2 and type(item[0]) == type(item[1]) == int: + return False, True else: - return False + return False, False def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] = None) -> int: # Conditional to deal with the different possible choices for the bands # Bands as lists of strings or enum - if all([isinstance(b, str) or isinstance(b, HLSBands) for b in model_bands]): - in_chans = len(model_bands) + is_sublist, requires_special_eval = _are_sublists_of_int(model_bands) + # Bands as intervals limited by integers - elif all([isinstance(b, int) for b in model_bands] or _are_sublists_of_int(model_bands)): + if requires_special_eval: - if _are_sublists_of_int(model_bands): + if is_sublist: in_chans = sum([i[-1] - i[0] for i in model_bands]) else: in_chans = model_bands[-1] - model_bands[0] else: - raise Exception(f"Expected bands to be list(str) or [int, int] but received {model_bands}") - + in_chans = len(model_bands) + return in_chans From ec8125ee448b471b5897d805e27cde975c96e2f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 16 Aug 2024 14:09:48 -0300 Subject: [PATCH 6/6] Intervals must include the upper extreme MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/models/backbones/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/utils.py b/terratorch/models/backbones/utils.py index e415e9b4..cb680f76 100644 --- a/terratorch/models/backbones/utils.py +++ b/terratorch/models/backbones/utils.py @@ -19,12 +19,14 @@ def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] is_sublist, requires_special_eval = _are_sublists_of_int(model_bands) # Bands as intervals limited by integers + # The bands numbering follows the Python convention (starts with 0) + # and includes the extrema (so the +1 in order to include the last band) if requires_special_eval: if is_sublist: - in_chans = sum([i[-1] - i[0] for i in model_bands]) + in_chans = sum([i[-1] - i[0] + 1 for i in model_bands]) else: - in_chans = model_bands[-1] - model_bands[0] + in_chans = model_bands[-1] - model_bands[0] + 1 else: in_chans = len(model_bands)