From ce045ef340e24fd1d8ea1a23af4d560dbc4dcd55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 8 Jul 2024 17:08:37 -0300 Subject: [PATCH 01/42] Trying to test for 3.12 (update packages if necessary) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d7896ed3..27dc3e56 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - name: Clone repo From 59a71bab0b92ee24ad5096b45974a02357648a49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 8 Jul 2024 17:22:12 -0300 Subject: [PATCH 02/42] Flexible version for torch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- requirements/required.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/required.txt b/requirements/required.txt index 06dd4ef4..07c9016a 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -2,7 +2,7 @@ torchgeo==0.5.1 rioxarray==0.15.0 albumentations==1.3.1 rasterio==1.3.9 -torch==2.1.0 +torch>=2.1.0 torchvision==0.16.0 torchmetrics==1.3.1 geopandas==0.14.2 @@ -10,4 +10,4 @@ lightly==1.4.25 h5py==3.10.0 geobench==1.0.0 mlflow==2.12.1 -lightning==2.2.5 \ No newline at end of file +lightning==2.2.5 From 2661c175ef2da97dbfd72701c5d98ed35f4e063f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 8 Jul 2024 17:24:06 -0300 Subject: [PATCH 03/42] Flexible version for torchvision and torchmetrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- requirements/required.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/required.txt b/requirements/required.txt index 07c9016a..fab0c654 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -3,8 +3,8 @@ rioxarray==0.15.0 albumentations==1.3.1 rasterio==1.3.9 torch>=2.1.0 -torchvision==0.16.0 -torchmetrics==1.3.1 +torchvision>=0.16.0 +torchmetrics>=1.3.1 geopandas==0.14.2 lightly==1.4.25 h5py==3.10.0 From 9cbff30dc02cba6a406dcd6c0760cbf76c994818 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 12:00:28 -0300 Subject: [PATCH 04/42] New fixed versions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- requirements/required.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/required.txt b/requirements/required.txt index fab0c654..755c9cf2 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -2,9 +2,9 @@ torchgeo==0.5.1 rioxarray==0.15.0 albumentations==1.3.1 rasterio==1.3.9 -torch>=2.1.0 -torchvision>=0.16.0 -torchmetrics>=1.3.1 +torch==2.3.1 +torchvision==0.18.1 +torchmetrics==1.4.0 geopandas==0.14.2 lightly==1.4.25 h5py==3.10.0 From c3c4f1f2dee81b2dc7ecad267c715d573ecddc2b Mon Sep 17 00:00:00 2001 From: Pedro Henrique Conrado Date: Fri, 2 Aug 2024 14:36:40 -0400 Subject: [PATCH 
05/42] adds smp model factory example Signed-off-by: Pedro Henrique Conrado --- examples/confs/smp_model_factory.yaml | 93 +++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 examples/confs/smp_model_factory.yaml diff --git a/examples/confs/smp_model_factory.yaml b/examples/confs/smp_model_factory.yaml new file mode 100644 index 00000000..7d9a1d53 --- /dev/null +++ b/examples/confs/smp_model_factory.yaml @@ -0,0 +1,93 @@ +benchmark_suffix: smp_test +experiment_name: smp_test +backbone: + backbone: resnet18 + backbone_args: + pretrained: False + output_stride: 2 + smp_decoder_channels: 512 + smp_encoder_depth: 5 + + # backbone: swin3d.swin3d_backbone.Swin3dBackbone + # backbone_args: + # pretrained: False + # output_stride: 2 + # out_channels: + # - 192 + # - 384 + # - 768 + # - 768 + # smp_decoder_channels: 768 + # smp_encoder_depth: 5 + + +tasks: + - name: cashew + type: segmentation + loss: ce + model_factory: SMPModelFactory + bands: + - RED + - GREEN + - BLUE + num_classes: 7 + max_epochs: 60 + direction: max + datamodule: + class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule + init_args: + batch_size: 16 + num_workers: 4 + train_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + test_transform: + - class_path: albumentations.Resize + init_args: + always_apply: True + height: 224 + width: 224 + - class_path: ToTensorV2 + val_transform: + - class_path: albumentations.Resize + init_args: + height: 224 + width: 224 + - class_path: ToTensorV2 + data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0" + bands: + - "RED" + - "GREEN" + - "BLUE" + decoder: IdentityDecoder + decoder_args: + channels: 128 + metric: val/Multiclass Jaccard Index + +n_trials: 16 +save_models: False +storage_uri: /path/to/storage +optimization_space: + model: + - DeepLabV3 + lr: + min: 6e-5 + max: 1e-3 + type: real + log: true + batch_size: + - 8 + - 16 + - 32 + decoder_channels: + - 32 + - 64 + - 128 + head_dropout: + min: 0.2 + max: 0.8 + type: real \ No newline at end of file From dd081366bac3a0120d953dae8ece6a2456093a19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 10 Sep 2024 10:52:38 -0300 Subject: [PATCH 06/42] Ignoring problematic dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- requirements/required.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/required.txt b/requirements/required.txt index 61b6f7c5..d49d8742 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -11,9 +11,9 @@ h5py==3.10.0 geobench==1.0.0 mlflow==2.14.3 lightning==2.2.5 -mmcv==2.0.0 +#mmcv==2.0.0 # Extra dependencies required by mmseg ftfy regex -openmim +#openmim #mim mmsegmentation From d75779fab6963507907a9a9df4ffe8c3e1c97496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 10 Sep 2024 10:59:56 -0300 Subject: [PATCH 07/42] avoiding mmsegmentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .github/workflows/test.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 96510882..27dc3e56 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -28,7 +28,6 @@ jobs: run: 
| python -m pip install --upgrade pip pip install -r requirements/required.txt -r requirements/test.txt - mim install mmsegmentation - name: List pip dependencies run: pip list - name: Test with pytest From 816a2296d97c9e352ff1c0b5489de5becb93d624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 10 Sep 2024 11:37:50 -0300 Subject: [PATCH 08/42] Linting for 3.12 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .github/workflows/pylint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 2d30dfd0..51802bed 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10","3.11"] + python-version: ["3.10","3.11","3.12"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} From b9b35a1761d9c10929cd13448d59a45ce9c5ef01 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 9 Oct 2024 19:05:05 +0200 Subject: [PATCH 09/42] Added generic multimodal dataset Signed-off-by: Benedikt Blumenstiel --- examples/confs/multimodal_sen1floods11.yaml | 183 ++++++ terratorch/datamodules/__init__.py | 4 +- .../generic_multimodal_data_module.py | 333 ++++++++++ terratorch/datasets/__init__.py | 8 + .../datasets/generic_multimodal_dataset.py | 586 ++++++++++++++++++ terratorch/tasks/segmentation_tasks.py | 3 +- 6 files changed, 1115 insertions(+), 2 deletions(-) create mode 100644 examples/confs/multimodal_sen1floods11.yaml create mode 100644 terratorch/datamodules/generic_multimodal_data_module.py create mode 100644 terratorch/datasets/generic_multimodal_dataset.py diff --git a/examples/confs/multimodal_sen1floods11.yaml b/examples/confs/multimodal_sen1floods11.yaml new file mode 100644 index 00000000..f591b229 --- /dev/null +++ b/examples/confs/multimodal_sen1floods11.yaml @@ -0,0 +1,183 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: output + name: sen1floods11_MM + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 2 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: output/sen1floods11_MM/ + +data: + class_path: GenericMultiModalDataModule + init_args: + task: 'segmentation' + batch_size: 4 + num_workers: 0 + modalities: + - S2L2A + - S1 + - LULC + S2L2A_dataset_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - CIRRUS + - SWIR_1 + - SWIR_2 + S2L2A_output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + S1_dataset_bands: + - vv + - vh + S1_output_bands: + - vv + - vh + LULC_dataset_bands: + - lulc + LULC_output_bands: + - lulc + rgb_modality: S2L2A # If not provided, uses first modality + rgb_indices: + - 0 + - 1 + - 2 + + train_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + train_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + train_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + 
train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + val_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + val_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + test_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + test_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + + train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt + val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt + test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt + + S2L2A_grep: "*_S2L2AHand.tif" + S1_grep: "*_S1Hand.tif" + LULC_grep: "*_LULCHand.npy" + label_grep: "*_LabelHand.tif" + no_data_replace: 0 + no_label_replace: -1 + + S2L2A_constant_scale: 1. + S1_constant_scale: 1. + LULC_constant_scale: 1. + + S2L2A_means: + - 0.1412956 + - 0.13795798 + - 0.12353792 + - 0.30902815 + - 0.2044958 + - 0.11912015 + S2L2A_stds: + - 0.07406382 + - 0.07370365 + - 0.08692279 + - 0.11798815 + - 0.09772074 + - 0.07659938 + S1_means: + - -20 + - -20 + S1_stds: + - 10 + - 10 + LULC_means: + - 0 + LULC_stds: + - 1 + + num_classes: 2 + +# train_transform: +# - class_path: albumentations.CenterCrop # TODO: How to handle transforms with multiple modalities? +# init_args: +# height: 224 +# width: 224 +# - class_path: albumentations.HorizontalFlip +# init_args: +# p: 0.5 +# - class_path: ToTensorV2 + + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_factory: PrithviModelFactory + model_args: + decoder: FCNDecoder + pretrained: true + backbone: prithvi_vit_100 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + num_classes: 2 + head_dropout: 0.1 + decoder_num_convs: 4 + head_channel_list: + - 256 + loss: ce + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + freeze_backbone: false + freeze_decoder: false + + +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py index 1ed966fd..9996f612 100644 --- a/terratorch/datamodules/__init__.py +++ b/terratorch/datamodules/__init__.py @@ -31,6 +31,7 @@ # GenericNonGeoRegressionDataModule, from terratorch.datamodules.sen1floods11 import Sen1Floods11NonGeoDataModule from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule +from terratorch.datamodules.generic_multimodal_data_module import GenericMultiModalDataModule __all__ = ( "GenericNonGeoSegmentationDataModule", @@ -56,5 +57,6 @@ "MNeonTreeNonGeoDataModule", "OpenSentinelMapDataModule", "PASTISDataModule", - "Sen4AgriNetDataModule" + "Sen4AgriNetDataModule", + "GenericMultiModalDataModule", ) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py new file mode 100644 index 00000000..f9a375ac --- /dev/null +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -0,0 
+1,333 @@ +# Copyright contributors to the Terratorch project + +""" +This module contains generic data modules for instantiation at runtime. +""" +import os +from collections.abc import Callable, Iterable +from pathlib import Path +from typing import Any + +import albumentations as A +import kornia.augmentation as K +import numpy as np +import torch +from torch import Tensor +from torch.utils.data import DataLoader +from torchgeo.datamodules import NonGeoDataModule +from torchgeo.transforms import AugmentationSequential + +from terratorch.datasets import (GenericMultimodalDataset, GenericMultimodalSegmentationDataset, + GenericMultimodalPixelwiseRegressionDataset, HLSBands) +from terratorch.io.file import load_from_file_or_attribute + + +def collate_chunk_dicts(batch_list): + batch = {} + for key, value in batch_list[0].items(): # TODO: Handle missing modalities when is allow_missing_modalities set. + if isinstance(value, torch.Tensor): + batch[key] = torch.concat([chunk[key] for chunk in batch_list]) + else: + batch[key] = [chunk[key] for chunk in batch_list] + return batch + + +def wrap_in_compose_is_list(transform_list): + # set check shapes to false because of the multitemporal case + return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list + + +class Normalize(Callable): + def __init__(self, means, stds): + super().__init__() + self.means = means + self.stds = stds + + def __call__(self, batch): + # min_value = self.means - 2 * self.stds + # max_value = self.means + 2 * self.stds + # img = (batch["image"] - min_value) / (max_value - min_value) + # img = torch.clip(img, 0, 1) + # batch["image"] = img + # return batch + image = batch["image"] + if len(image.shape) == 5: + means = torch.tensor(self.means, device=image.device).view(1, -1, 1, 1, 1) + stds = torch.tensor(self.stds, device=image.device).view(1, -1, 1, 1, 1) + elif len(image.shape) == 4: + means = torch.tensor(self.means, device=image.device).view(1, -1, 1, 1) + stds = torch.tensor(self.stds, device=image.device).view(1, -1, 1, 1) + else: + msg = f"Expected batch to have 5 or 4 dimensions, but got {len(image.shape)}" + raise Exception(msg) + batch["image"] = (image - means) / stds + return batch + + +class GenericMultiModalDataModule(NonGeoDataModule): + """ + This is a generic datamodule class for instantiating data modules at runtime. 
+ Composes several [GenericNonGeoSegmentationDatasets][terratorch.datasets.GenericNonGeoSegmentationDataset] + """ + + def __init__( + self, + batch_size: int, + num_workers: int, + modalities: list[str], + task: str | None = None, + num_classes: int | None = None, + label_grep: str | None = None, + train_label_data_root: Path | None = None, + val_label_data_root: Path | None = None, + test_label_data_root: Path | None = None, + predict_data_root: Path | None = None, + train_split: Path | None = None, + val_split: Path | None = None, + test_split: Path | None = None, + ignore_split_file_extensions: bool = True, + allow_substring_split_file: bool = True, # TODO: Check if covered + predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_modality: str | None = None, + rgb_modality: str | None = None, + rgb_indices: list[int] | None = None, + # TODO: Check how to handle transforms + train_transform: A.Compose | None | list[A.BasicTransform] = None, + val_transform: A.Compose | None | list[A.BasicTransform] = None, + test_transform: A.Compose | None | list[A.BasicTransform] = None, + expand_temporal_dimension: bool = False, + reduce_zero_label: bool = False, + no_data_replace: float | None = None, + no_label_replace: int | None = None, + drop_last: bool = True, + chunk_data: bool = False, + **kwargs: Any, + ) -> None: + """Constructor + + Args: + # TODO: Update docs + batch_size (int): _description_ + num_workers (int): _description_ + train_data_root (Path): _description_ + val_data_root (Path): _description_ + test_data_root (Path): _description_ + predict_data_root (Path): _description_ + img_grep (str): _description_ + label_grep (str): _description_ + means (list[float]): _description_ + stds (list[float]): _description_ + num_classes (int): _description_ + train_label_data_root (Path | None, optional): _description_. Defaults to None. + val_label_data_root (Path | None, optional): _description_. Defaults to None. + test_label_data_root (Path | None, optional): _description_. Defaults to None. + train_split (Path | None, optional): _description_. Defaults to None. + val_split (Path | None, optional): _description_. Defaults to None. + test_split (Path | None, optional): _description_. Defaults to None. + ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split + file to determine which files to include in the dataset. + E.g. necessary for Eurosat, since the split files specify ".jpg" but files are + actually ".jpg". Defaults to True. + allow_substring_split_file (bool, optional): Whether the split files contain substrings + that must be present in file names to be included (as in mmsegmentation), or exact + matches (e.g. eurosat). Defaults to True. + dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None. + output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + Naming must match that of dataset_bands. Defaults to None. + predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands + with this value at predict time. + Defaults to None, which does not overwrite. + predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands + with this value at predict time. Defaults to None, which does not overwrite. + constant_scale (float, optional): _description_. Defaults to 1. + rgb_indices (list[int] | None, optional): _description_. 
Defaults to None. + train_transform (Albumentations.Compose | None): Albumentations transform + to be applied to the train dataset. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + val_transform (Albumentations.Compose | None): Albumentations transform + to be applied to the train dataset. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + test_transform (Albumentations.Compose | None): Albumentations transform + to be applied to the train dataset. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Defaults to False. + reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the + expected 0. Defaults to False. + drop_last (bool): Drop the last batch if it is not complete. Defaults to True. + """ + if task == 'segmentation': + dataset_class = GenericMultimodalSegmentationDataset + elif task == 'regression': + dataset_class = GenericMultimodalPixelwiseRegressionDataset + elif task is None: + dataset_class = GenericMultimodalDataset + else: + raise ValueError(f'Unknown task {task}, only segmentation and regression are supported.') + + super().__init__(dataset_class, batch_size, num_workers, **kwargs) + self.num_classes = num_classes + self.modalities = modalities + self.img_grep = {m: kwargs.get('{m}_grep', '*') for m in modalities} + self.label_grep = label_grep + self.train_root = {m: kwargs[f'train_{m}_data_root'] for m in modalities} + self.val_root = {m: kwargs[f'val_{m}_data_root'] for m in modalities} + self.test_root = {m: kwargs[f'test_{m}_data_root'] for m in modalities} + self.train_label_data_root = train_label_data_root + self.val_label_data_root = val_label_data_root + self.test_label_data_root = test_label_data_root + self.predict_root = predict_data_root + self.train_split = train_split + self.val_split = val_split + self.test_split = test_split + self.ignore_split_file_extensions = ignore_split_file_extensions + self.allow_substring_split_file = allow_substring_split_file + self.constant_scale = {m: kwargs.get(f'{m}_constant_scale', 1.) 
for m in modalities} + self.no_data_replace = no_data_replace + self.no_label_replace = no_label_replace + self.drop_last = drop_last + + self.dataset_bands = {m: kwargs.get(f'{m}_dataset_bands') for m in modalities if f'{m}_dataset_bands' in kwargs} + self.output_bands = {m: kwargs.get(f'{m}_output_bands') for m in modalities if f'{m}_output_bands' in kwargs} + + self.predict_dataset_bands = predict_dataset_bands or self.dataset_bands[predict_modality] \ + if predict_modality in self.dataset_bands else None + self.predict_output_bands = predict_output_bands or self.output_bands[predict_modality] \ + if predict_modality in self.output_bands else None + + self.rgb_modality = rgb_modality or modalities[0] + self.rgb_indices = rgb_indices + self.expand_temporal_dimension = expand_temporal_dimension + self.reduce_zero_label = reduce_zero_label + + assert train_transform is None, "transforms are not implemented yet" # TODO Handle transforms + self.train_transform = wrap_in_compose_is_list(train_transform) + self.val_transform = wrap_in_compose_is_list(val_transform) + self.test_transform = wrap_in_compose_is_list(test_transform) + + # self.aug = AugmentationSequential( + # K.Normalize(means, stds), + # data_keys=["image"], + # ) + means = {m: load_from_file_or_attribute(kwargs[f'{m}_means']) for m in modalities} + stds = {m: load_from_file_or_attribute(kwargs[f'{m}_stds']) for m in modalities} + + self.aug = {m: Normalize(means[m], stds[m]) for m in modalities} + + self.chunk_data = chunk_data + if chunk_data: + self.collate_fn = collate_chunk_dicts + # self.collate_fn = collate_fn_list_dicts + + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = self.dataset_class( + data_root=self.train_root, + num_classes=self.num_classes, + image_grep=self.img_grep, + label_grep=self.label_grep, + label_data_root=self.train_label_data_root, + split=self.train_split, + ignore_split_file_extensions=self.ignore_split_file_extensions, + allow_substring_split_file=self.allow_substring_split_file, + dataset_bands=self.dataset_bands, + output_bands=self.output_bands, + constant_scale=self.constant_scale, + rgb_modality=self.rgb_modality, + rgb_indices=self.rgb_indices, + transform=self.train_transform, + no_data_replace=self.no_data_replace, + no_label_replace=self.no_label_replace, + expand_temporal_dimension=self.expand_temporal_dimension, + reduce_zero_label=self.reduce_zero_label, + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( + data_root=self.val_root, + num_classes=self.num_classes, + image_grep=self.img_grep, + label_grep=self.label_grep, + label_data_root=self.val_label_data_root, + split=self.val_split, + ignore_split_file_extensions=self.ignore_split_file_extensions, + allow_substring_split_file=self.allow_substring_split_file, + dataset_bands=self.dataset_bands, + output_bands=self.output_bands, + constant_scale=self.constant_scale, + rgb_modality=self.rgb_modality, + rgb_indices=self.rgb_indices, + transform=self.val_transform, + no_data_replace=self.no_data_replace, + no_label_replace=self.no_label_replace, + expand_temporal_dimension=self.expand_temporal_dimension, + reduce_zero_label=self.reduce_zero_label, + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( + data_root=self.test_root, + num_classes=self.num_classes, + image_grep=self.img_grep, + label_grep=self.label_grep, + label_data_root=self.test_label_data_root, + split=self.test_split, + ignore_split_file_extensions=self.ignore_split_file_extensions, + 
allow_substring_split_file=self.allow_substring_split_file, + dataset_bands=self.dataset_bands, + output_bands=self.output_bands, + constant_scale=self.constant_scale, + rgb_modality=self.rgb_modality, + rgb_indices=self.rgb_indices, + transform=self.test_transform, + no_data_replace=self.no_data_replace, + no_label_replace=self.no_label_replace, + expand_temporal_dimension=self.expand_temporal_dimension, + reduce_zero_label=self.reduce_zero_label, + ) + if stage in ["predict"] and self.predict_root: + self.predict_dataset = self.dataset_class( + data_root=self.predict_root, + num_classes=self.num_classes, + dataset_bands=self.predict_dataset_bands, + output_bands=self.predict_output_bands, + constant_scale=self.constant_scale, + rgb_modality=self.rgb_modality, + rgb_indices=self.rgb_indices, + transform=self.test_transform, + no_data_replace=self.no_data_replace, + no_label_replace=self.no_label_replace, + expand_temporal_dimension=self.expand_temporal_dimension, + reduce_zero_label=self.reduce_zero_label, + ) + + def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders. + + Args: + split: Either 'train', 'val', 'test', or 'predict'. + + Returns: + A collection of data loaders specifying samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + dataset or sampler, or if the dataset or sampler has length 0. + """ + dataset = self._valid_attribute(f"{split}_dataset", "dataset") + batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=split == "train", + num_workers=self.num_workers, + collate_fn=self.collate_fn, + drop_last=split == "train" and self.drop_last, + ) diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index 0b241e80..81bd548d 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -8,6 +8,11 @@ from terratorch.datasets.generic_scalar_label_dataset import ( GenericNonGeoClassificationDataset, ) +from terratorch.datasets.generic_multimodal_dataset import ( + GenericMultimodalDataset, + GenericMultimodalSegmentationDataset, + GenericMultimodalPixelwiseRegressionDataset +) from terratorch.datasets.hls import HLSL30, HLSS30 from terratorch.datasets.m_bigearthnet import MBigEarthNonGeo from terratorch.datasets.m_brick_kiln import MBrickKilnNonGeo @@ -41,6 +46,9 @@ "GenericNonGeoPixelwiseRegressionDataset", "GenericNonGeoClassificationDataset", "GenericNonGeoRegressionDataset", + "GenericMultimodalDataset", + "GenericMultimodalSegmentationDataset", + "GenericMultimodalPixelwiseRegressionDataset", "FireScarsNonGeo", "FireScarsHLS", "FireScarsSegmentationMask", diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py new file mode 100644 index 00000000..8bc25fea --- /dev/null +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -0,0 +1,586 @@ +# Copyright contributors to the Terratorch project + +"""Module containing generic dataset classes""" + +import glob +import logging +import os +from abc import ABC +from pathlib import Path +from typing import Any + +import albumentations as A +import matplotlib as mpl +import numpy as np +import rioxarray +import xarray as xr +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from matplotlib.patches import Rectangle +from torch import Tensor +from torchgeo.datasets import NonGeoDataset + 
+from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files, generate_bands_intervals + + +class GenericMultimodalDataset(NonGeoDataset, ABC): + """ + This is a generic dataset class to be used for instantiating datasets from arguments. + Ideally, one would create a dataset class specific to a dataset. + """ + + def __init__( + self, + data_root: Path, + label_data_root: Path | None = None, + image_grep: str | None = "*", + label_grep: str | None = "*", + split: Path | None = None, + ignore_split_file_extensions: bool = True, + allow_substring_split_file: bool = True, + rgb_modality: str | None = None, + rgb_indices: list[int] | None = None, + allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet (collate_fn required). + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + constant_scale: dict[float] = None, + transform: A.Compose | None = None, + no_data_replace: float | None = None, + no_label_replace: int | None = None, + expand_temporal_dimension: bool = False, + reduce_zero_label: bool = False, + *args, **kwargs, + ) -> None: + """Constructor + + Args: + data_root (Path): Path to data root directory + label_data_root (Path, optional): Path to data root directory with labels. + If not specified, will use the same as for images. + image_grep (str, optional): Regular expression appended to data_root to find input images. + Defaults to "*". + label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. + Defaults to "*". + split (Path, optional): Path to file containing files to be used for this split. + The file should be a new-line separated prefixes contained in the desired files. + Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split + file to determine which files to include in the dataset. + E.g. necessary for Eurosat, since the split files specify ".jpg" but files are + actually ".jpg". Defaults to True. + allow_substring_split_file (bool, optional): Whether the split files contain substrings + that must be present in file names to be included (as in mmsegmentation), or exact + matches (e.g. eurosat). Defaults to True. + rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. + dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None. + output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands. + constant_scale (float): Factor to multiply image values by. Defaults to 1. + transform (Albumentations.Compose | None): Albumentations transform to be applied. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to -1. 
+ expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Defaults to False. + reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the + expected 0. Defaults to False. + """ + super().__init__() + + self.split_file = split + + self.modalities = list(data_root.keys()) + assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." + self.label_data_root = label_data_root + # Get files per modality + if image_grep: + self.image_files = {m: sorted(glob.glob(os.path.join(m_root, image_grep[m]))) + for m, m_root in data_root.items()} + else: + self.image_files = {m: sorted(glob.glob(m_root)) for m, m_root in data_root.items()} + self.constant_scale = constant_scale or {m: 1. for m in self.modalities} + self.no_data_replace = no_data_replace + self.no_label_replace = no_label_replace + if self.label_data_root: + self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep))) + else: + self.segmentation_mask_files = None + self.reduce_zero_label = reduce_zero_label + self.expand_temporal_dimension = expand_temporal_dimension + + if self.expand_temporal_dimension and len(dataset_bands) != self.modalities: + msg = "Please provide dataset_bands for each modality when expand_temporal_dimension is True" + raise Exception(msg) + + if self.split_file is not None: + with open(self.split_file) as f: + split = f.readlines() + valid_files = {rf"{substring.strip()}" for substring in split} + if not ignore_split_file_extensions or not allow_substring_split_file: + # TODO: Only need for special cases, can we generalize the multi-modal samples and remove this part? + for m, m_files in self.image_files.items(): + self.image_files[m] = filter_valid_files( + m_files, + valid_files=valid_files, + ignore_extensions=ignore_split_file_extensions, + allow_substring=allow_substring_split_file, + ) + if self.segmentation_mask_files: + self.segmentation_mask_files = filter_valid_files( + self.segmentation_mask_files, + valid_files=valid_files, + ignore_extensions=ignore_split_file_extensions, + allow_substring=allow_substring_split_file, + ) + else: + logging.warning('No split file provided. 
' + 'This requires that all modalities have the same filename for aligned samples.') + if allow_missing_modalities: + all_files = [os.path.splitext(os.path.basename(file))[0] + for file in np.concatenate(list(self.image_files.values()))] + valid_files = list(set(all_files)) + else: + valid_files = [os.path.splitext(os.path.basename(file))[0] + for file in self.image_files[self.modalities[0]]] + logging.info(f'Found {len(valid_files)} file names.') + + # Get multi-modal samples in form: {'modality': path, ..., 'mask': path} + self.samples = [] + num_modalities = len(self.modalities) + int(self.segmentation_mask_files is not None) + for split_file in valid_files: + sample = {} + for m, files in self.image_files.items(): + matching_files = [file for file in files if split_file in file] + if matching_files: + sample[m] = matching_files[0] + if self.segmentation_mask_files: + matching_files = [file for file in self.segmentation_mask_files if split_file in file] + if matching_files: + sample['mask'] = matching_files[0] + else: + # Skip samples with missing labels + continue + if allow_missing_modalities or len(sample) == num_modalities: + self.samples.append(sample) + + self.rgb_modality = rgb_modality or self.modalities[0] + self.rgb_indices = rgb_indices or [0, 1, 2] + + if dataset_bands is not None: + self.dataset_bands = {m: generate_bands_intervals(m_bands) + for m, m_bands in dataset_bands.items()} + if output_bands is not None: + self.output_bands = {m: generate_bands_intervals(m_bands) + for m, m_bands in output_bands.items()} + + for modality in self.modalities: + if modality in self.output_bands and modality not in self.dataset_bands: + msg = f"If output bands are provided, dataset_bands must also be provided (modality: {modality})" + raise Exception(msg) # noqa: PLE0101 + + self.filter_indices = {} + # There is a special condition if the bands are defined as simple strings. 
+ if self.output_bands: + for m in self.output_bands.keys(): + if len(set(self.output_bands[m]) & set(self.dataset_bands[m])) != len(self.output_bands[m]): + msg = f"Output bands must be a subset of dataset bands (Modality: {m})" + raise Exception(msg) + if self.output_bands[m] == self.dataset_bands[m]: + continue + + self.filter_indices[m] = [self.dataset_bands[m].index(band) for band in self.output_bands[m]] + + # TODO: Implement multi-modal transforms + # If no transform is given, apply only to transform to torch tensor + self.transform = transform if transform else default_transform + + import warnings + import rasterio + warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + output = {} + for modality, file in self.samples[index].items(): + data = self._load_file( + file, nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace).to_numpy() + + # Expand temporal dim + if modality in self.filter_indices and self.expand_temporal_dimension: + data = rearrange(data, "(channels time) h w -> channels time h w", + channels=len(self.dataset_bands[modality])) + + if modality == 'mask': + data = data[0] + + # TODO: Assumes all modalities with three dimension and more to be channel-first images + if len(data.shape) >= 3: + # to channels last + data = np.moveaxis(data, -3, -1) + + if modality in self.filter_indices: + data = data[..., self.filter_indices[modality]] + + if modality != 'mask': + data = data.astype(np.float32) * self.constant_scale[modality] + + output[modality] = data + + if self.reduce_zero_label: + output["mask"] -= 1 + if self.transform: + output = self.transform(**output) + output["filename"] = self.samples[index] + + return output + + def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: + if path.endswith('.zarr') or path.endswith('.zarr.zip'): + data = xr.open_zarr(path, mask_and_scale=True) + data_var = list(data.data_vars)[0] # TODO: Make data var configurable if required (e.g. for time/loc) + data = data[data_var] + elif path.endswith('.npy'): + data = xr.DataArray(np.load(path)) + else: + data = rioxarray.open_rasterio(path, masked=True) + + if nan_replace is not None: + data = data.fillna(nan_replace) + return data + + +class GenericMultimodalSegmentationDataset(GenericMultimodalDataset): + """GenericNonGeoSegmentationDataset""" + + def __init__( + self, + data_root: Path, + num_classes: int, + label_data_root: Path, + image_grep: str | None = "*", + label_grep: str | None = "*", + split: Path | None = None, + ignore_split_file_extensions: bool = True, + allow_substring_split_file: bool = True, + rgb_modality: str | None = None, + rgb_indices: list[str] | None = None, + allow_missing_modalities: bool = False, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + class_names: list[str] | None = None, + constant_scale: float = 1, + transform: A.Compose | None = None, + no_data_replace: float | None = None, + no_label_replace: int | None = None, + expand_temporal_dimension: bool = False, + reduce_zero_label: bool = False, + ) -> None: + """Constructor + + Args: + TODO: Update docs + data_root (Path): Path to data root directory + num_classes (int): Number of classes in the dataset + label_data_root (Path, optional): Path to data root directory with labels. 
+ If not specified, will use the same as for images. + image_grep (str, optional): Regular expression appended to data_root to find input images. + Defaults to "*". + label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. + Defaults to "*". + split (Path, optional): Path to file containing files to be used for this split. + The file should be a new-line separated prefixes contained in the desired files. + Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split + file to determine which files to include in the dataset. + E.g. necessary for Eurosat, since the split files specify ".jpg" but files are + actually ".jpg". Defaults to True + allow_substring_split_file (bool, optional): Whether the split files contain substrings + that must be present in file names to be included (as in mmsegmentation), or exact + matches (e.g. eurosat). Defaults to True. + rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. + dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. + output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + class_names (list[str], optional): Class names. Defaults to None. + constant_scale (float): Factor to multiply image values by. Defaults to 1. + transform (Albumentations.Compose | None): Albumentations transform to be applied. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Defaults to False. + reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the + expected 0. Defaults to False. + """ + super().__init__( + data_root, + label_data_root=label_data_root, + image_grep=image_grep, + label_grep=label_grep, + split=split, + ignore_split_file_extensions=ignore_split_file_extensions, + allow_substring_split_file=allow_substring_split_file, + rgb_modality=rgb_modality, + rgb_indices=rgb_indices, + allow_missing_modalities=allow_missing_modalities, + dataset_bands=dataset_bands, + output_bands=output_bands, + constant_scale=constant_scale, + transform=transform, + no_data_replace=no_data_replace, + no_label_replace=no_label_replace, + expand_temporal_dimension=expand_temporal_dimension, + reduce_zero_label=reduce_zero_label, + ) + self.num_classes = num_classes + self.class_names = class_names + + def __getitem__(self, index: int) -> dict[str, Any]: + item = super().__getitem__(index) + item["mask"] = item["mask"].long() + return item + + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + raise NotImplementedError('Code is based on the generic single-modality dataset and not yet adapted. 
' + 'Set `export TERRATORCH_NUM_VAL_PLOTS=0` before running terratorch.') + + image = sample[self.rgb_modality] + if len(image.shape) == 5: # TODO: Needed? Copied from generic dataest. + return + if isinstance(image, Tensor): + image = image.numpy() + image = image.take(self.rgb_indices, axis=0) + image = np.transpose(image, (1, 2, 0)) + image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1))) + image = np.clip(image, 0, 1) + + label_mask = sample["mask"] + if isinstance(label_mask, Tensor): + label_mask = label_mask.numpy() + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_mask = sample["prediction"] + if isinstance(prediction_mask, Tensor): + prediction_mask = prediction_mask.numpy() + + return self._plot_sample( + image, + label_mask, + self.num_classes, + prediction=prediction_mask if showing_predictions else None, + suptitle=suptitle, + class_names=self.class_names, + ) + + @staticmethod + def _plot_sample(image, label, num_classes, prediction=None, suptitle=None, class_names=None): + num_images = 5 if prediction is not None else 4 + fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed") + + # for legend + ax[0].axis("off") + + norm = mpl.colors.Normalize(vmin=0, vmax=num_classes - 1) + ax[1].axis("off") + ax[1].title.set_text("Image") + ax[1].imshow(image) + + ax[2].axis("off") + ax[2].title.set_text("Ground Truth Mask") + ax[2].imshow(label, cmap="jet", norm=norm) + + ax[3].axis("off") + ax[3].title.set_text("GT Mask on Image") + ax[3].imshow(image) + ax[3].imshow(label, cmap="jet", alpha=0.3, norm=norm) + + if prediction is not None: + ax[4].title.set_text("Predicted Mask") + ax[4].imshow(prediction, cmap="jet", norm=norm) + + cmap = plt.get_cmap("jet") + legend_data = [] + for i, _ in enumerate(range(num_classes)): + class_name = class_names[i] if class_names else str(i) + data = [i, cmap(norm(i)), class_name] + legend_data.append(data) + handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data] + labels = [n for k, c, n in legend_data] + ax[0].legend(handles, labels, loc="center") + if suptitle is not None: + plt.suptitle(suptitle) + return fig + + +class GenericMultimodalPixelwiseRegressionDataset(GenericMultimodalDataset): + """GenericNonGeoPixelwiseRegressionDataset""" + + def __init__( + self, + data_root: Path, + label_data_root: Path, + image_grep: str | None = "*", + label_grep: str | None = "*", + split: Path | None = None, + ignore_split_file_extensions: bool = True, + allow_substring_split_file: bool = True, + rgb_modality: str | None = None, + rgb_indices: list[int] | None = None, + allow_missing_modalities : bool = False, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + constant_scale: float = 1, + transform: A.Compose | None = None, + no_data_replace: float | None = None, + no_label_replace: int | None = None, + expand_temporal_dimension: bool = False, + reduce_zero_label: bool = False, + ) -> None: + """Constructor + + Args: + TODO: Update docs + data_root (Path): Path to data root directory + label_data_root (Path, optional): Path to data root directory with labels. + If not specified, will use the same as for images. + image_grep (str, optional): Regular expression appended to data_root to find input images. + Defaults to "*". + label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. + Defaults to "*". 
+ split (Path, optional): Path to file containing files to be used for this split. + The file should be a new-line separated prefixes contained in the desired files. + Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split + file to determine which files to include in the dataset. + E.g. necessary for Eurosat, since the split files specify ".jpg" but files are + actually ".jpg". Defaults to True. + allow_substring_split_file (bool, optional): Whether the split files contain substrings + that must be present in file names to be included (as in mmsegmentation), or exact + matches (e.g. eurosat). Defaults to True. + rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. + dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. + output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + constant_scale (float): Factor to multiply image values by. Defaults to 1. + transform (Albumentations.Compose | None): Albumentations transform to be applied. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Defaults to False. + reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the + expected 0. Defaults to False. + """ + super().__init__( + data_root, + label_data_root=label_data_root, + image_grep=image_grep, + label_grep=label_grep, + split=split, + ignore_split_file_extensions=ignore_split_file_extensions, + allow_substring_split_file=allow_substring_split_file, + rgb_modality=rgb_modality, + rgb_indices=rgb_indices, + allow_missing_modalities=allow_missing_modalities, + dataset_bands=dataset_bands, + output_bands=output_bands, + constant_scale=constant_scale, + transform=transform, + no_data_replace=no_data_replace, + no_label_replace=no_label_replace, + expand_temporal_dimension=expand_temporal_dimension, + reduce_zero_label=reduce_zero_label, + ) + + def __getitem__(self, index: int) -> dict[str, Any]: + item = super().__getitem__(index) + item["mask"] = item["mask"].float() + return item + + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__` + suptitle (str|None): optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + raise NotImplementedError('Code is based on the generic single-modality dataset and not yet adapted. 
' + 'Set `export TERRATORCH_NUM_VAL_PLOTS=0` before running terratorch.') + + image = sample["image"] + if len(image.shape) == 5: + return + if isinstance(image, Tensor): + image = image.numpy() + image = image.take(self.rgb_indices, axis=0) + image = np.transpose(image, (1, 2, 0)) + image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1))) + image = np.clip(image, 0, 1) + + label_mask = sample["mask"] + if isinstance(label_mask, Tensor): + label_mask = label_mask.numpy() + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_mask = sample["prediction"] + if isinstance(prediction_mask, Tensor): + prediction_mask = prediction_mask.numpy() + + return self._plot_sample( + image, + label_mask, + prediction=prediction_mask if showing_predictions else None, + suptitle=suptitle, + ) + + @staticmethod + def _plot_sample(image, label, prediction=None, suptitle=None): + num_images = 4 if prediction is not None else 3 + fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed") + + norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max()) + ax[0].axis("off") + ax[0].title.set_text("Image") + ax[0].imshow(image) + + ax[1].axis("off") + ax[1].title.set_text("Ground Truth Mask") + ax[1].imshow(label, cmap="Greens", norm=norm) + + ax[2].axis("off") + ax[2].title.set_text("GT Mask on Image") + ax[2].imshow(image) + ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm) + # ax[2].legend() + + if prediction is not None: + ax[3].title.set_text("Predicted Mask") + ax[3].imshow(prediction, cmap="Greens", norm=norm) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 9f527d60..9279847f 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -1,3 +1,4 @@ +import os from typing import Any import lightning @@ -16,7 +17,7 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference -BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 +BATCH_IDX_FOR_VALIDATION_PLOTTING = os.getenv('TERRATORCH_NUM_VAL_PLOTS', 10) def to_segmentation_prediction(y: ModelOutput) -> Tensor: From 3134db9e3707351d95d9bd65930e5eedd220d65b Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 18 Oct 2024 15:50:04 +0200 Subject: [PATCH 10/42] Added pin memory Signed-off-by: Benedikt Blumenstiel --- .../datamodules/generic_pixel_wise_data_module.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 4f39fbca..90c69cd7 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -107,6 +107,7 @@ def __init__( no_data_replace: float | None = None, no_label_replace: int | None = None, drop_last: bool = True, + pin_memory: bool = True, **kwargs: Any, ) -> None: """Constructor @@ -168,6 +169,8 @@ def __init__( reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. drop_last (bool): Drop the last batch if it is not complete. Defaults to True. + pin_memory (bool): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. Defaults to False. 
""" super().__init__(GenericNonGeoSegmentationDataset, batch_size, num_workers, **kwargs) self.num_classes = num_classes @@ -186,6 +189,7 @@ def __init__( self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace self.drop_last = drop_last + self.pin_memory = pin_memory self.train_label_data_root = train_label_data_root self.val_label_data_root = val_label_data_root @@ -313,6 +317,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, drop_last=split == "train" and self.drop_last, + pin_memory=self.pin_memory, ) @@ -356,6 +361,7 @@ def __init__( no_data_replace: float | None = None, no_label_replace: int | None = None, drop_last: bool = True, + pin_memory: bool = True, **kwargs: Any, ) -> None: """Constructor @@ -416,6 +422,9 @@ def __init__( reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. drop_last (bool): Drop the last batch if it is not complete. Defaults to True. + pin_memory (bool): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. Defaults to False. + """ super().__init__(GenericNonGeoPixelwiseRegressionDataset, batch_size, num_workers, **kwargs) self.img_grep = img_grep @@ -430,6 +439,7 @@ def __init__( self.ignore_split_file_extensions = ignore_split_file_extensions self.allow_substring_split_file = allow_substring_split_file self.drop_last = drop_last + self.pin_memory = pin_memory self.expand_temporal_dimension = expand_temporal_dimension self.reduce_zero_label = reduce_zero_label @@ -555,4 +565,5 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, drop_last=split == "train" and self.drop_last, + pin_memory=self.pin_memory, ) From 5c1b7c0903498e6298e50bbfb0887824237f6684 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 18 Oct 2024 17:16:18 +0200 Subject: [PATCH 11/42] Updated multimodel config and add transforms Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 109 ++++++++++++------ .../datasets/generic_multimodal_dataset.py | 38 ++++-- terratorch/datasets/transforms.py | 88 +++++++++++++- 3 files changed, 186 insertions(+), 49 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index f9a375ac..6fe408bf 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -24,7 +24,7 @@ def collate_chunk_dicts(batch_list): batch = {} - for key, value in batch_list[0].items(): # TODO: Handle missing modalities when is allow_missing_modalities set. + for key, value in batch_list[0].items(): # TODO: Handle missing modalities when allow_missing_modalities is set. 
if isinstance(value, torch.Tensor): batch[key] = torch.concat([chunk[key] for chunk in batch_list]) else: @@ -32,9 +32,12 @@ def collate_chunk_dicts(batch_list): return batch -def wrap_in_compose_is_list(transform_list): +def wrap_in_compose_is_list(transform_list, additional_targets=None): # set check shapes to false because of the multitemporal case - return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list + if additional_targets: + additional_targets = {m: 'image' for m in additional_targets} + return A.Compose(transform_list, is_check_shapes=False, additional_targets=additional_targets) \ + if isinstance(transform_list, Iterable) else transform_list class Normalize(Callable): @@ -44,12 +47,6 @@ def __init__(self, means, stds): self.stds = stds def __call__(self, batch): - # min_value = self.means - 2 * self.stds - # max_value = self.means + 2 * self.stds - # img = (batch["image"] - min_value) / (max_value - min_value) - # img = torch.clip(img, 0, 1) - # batch["image"] = img - # return batch image = batch["image"] if len(image.shape) == 5: means = torch.tensor(self.means, device=image.device).view(1, -1, 1, 1, 1) @@ -75,6 +72,11 @@ def __init__( batch_size: int, num_workers: int, modalities: list[str], + train_data_root: dict, + val_data_root: dict, + test_data_root: dict, + means: dict, + stds: dict, task: str | None = None, num_classes: int | None = None, label_grep: str | None = None, @@ -87,20 +89,23 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, # TODO: Check if covered + dataset_bands: dict | None = None, + output_bands: dict | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - predict_modality: str | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, - # TODO: Check how to handle transforms - train_transform: A.Compose | None | list[A.BasicTransform] = None, - val_transform: A.Compose | None | list[A.BasicTransform] = None, - test_transform: A.Compose | None | list[A.BasicTransform] = None, + constant_scale: dict | float = 1., + train_transform: dict | A.Compose | None | list[A.BasicTransform] = None, + val_transform: dict | A.Compose | None | list[A.BasicTransform] = None, + test_transform: dict | A.Compose | None | list[A.BasicTransform] = None, + shared_transforms: list | bool = True, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, no_data_replace: float | None = None, no_label_replace: int | None = None, drop_last: bool = True, + pin_memory: bool = False, chunk_data: bool = False, **kwargs: Any, ) -> None: @@ -164,6 +169,9 @@ def __init__( reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. drop_last (bool): Drop the last batch if it is not complete. Defaults to True. + pin_memory (bool): If ``True``, the data loader will copy Tensors + into device/CUDA pinned memory before returning them. Defaults to False. 
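
To make the per-modality dictionary convention concrete, a hypothetical set of constructor arguments (modality names, paths and statistics are invented for illustration):

    modalities = ["S2L2A", "S1GRD"]
    train_data_root = {"S2L2A": "data/train/s2", "S1GRD": "data/train/s1"}
    means = {"S2L2A": [1369.0, 1597.0, 1741.0, 2858.0], "S1GRD": [-12.5, -20.1]}
    stds = {"S2L2A": [2340.0, 2375.0, 2474.0, 2387.0], "S1GRD": [5.3, 5.9]}
    # Every dict-valued argument is keyed by the same modality names; constant_scale
    # may be a single float applied to all modalities or a dict per modality.
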
+ """ if task == 'segmentation': dataset_class = GenericMultimodalSegmentationDataset @@ -179,9 +187,9 @@ def __init__( self.modalities = modalities self.img_grep = {m: kwargs.get('{m}_grep', '*') for m in modalities} self.label_grep = label_grep - self.train_root = {m: kwargs[f'train_{m}_data_root'] for m in modalities} - self.val_root = {m: kwargs[f'val_{m}_data_root'] for m in modalities} - self.test_root = {m: kwargs[f'test_{m}_data_root'] for m in modalities} + self.train_root = train_data_root + self.val_root = val_data_root + self.test_root = test_data_root self.train_label_data_root = train_label_data_root self.val_label_data_root = val_label_data_root self.test_label_data_root = test_label_data_root @@ -191,42 +199,70 @@ def __init__( self.test_split = test_split self.ignore_split_file_extensions = ignore_split_file_extensions self.allow_substring_split_file = allow_substring_split_file - self.constant_scale = {m: kwargs.get(f'{m}_constant_scale', 1.) for m in modalities} + self.constant_scale = constant_scale + if isinstance(self.constant_scale, dict): + # Fill in missing modalities + self.constant_scale = {m: self.constant_scale[m] if m in self.constant_scale else 1. + for m in modalities} + else: + # Create dict + self.constant_scale = {m: constant_scale for m in modalities} self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace self.drop_last = drop_last + self.pin_memory = pin_memory - self.dataset_bands = {m: kwargs.get(f'{m}_dataset_bands') for m in modalities if f'{m}_dataset_bands' in kwargs} - self.output_bands = {m: kwargs.get(f'{m}_output_bands') for m in modalities if f'{m}_output_bands' in kwargs} - - self.predict_dataset_bands = predict_dataset_bands or self.dataset_bands[predict_modality] \ - if predict_modality in self.dataset_bands else None - self.predict_output_bands = predict_output_bands or self.output_bands[predict_modality] \ - if predict_modality in self.output_bands else None + self.dataset_bands = dataset_bands + self.output_bands = output_bands + self.predict_dataset_bands = predict_dataset_bands + self.predict_output_bands = predict_output_bands self.rgb_modality = rgb_modality or modalities[0] self.rgb_indices = rgb_indices self.expand_temporal_dimension = expand_temporal_dimension self.reduce_zero_label = reduce_zero_label - assert train_transform is None, "transforms are not implemented yet" # TODO Handle transforms - self.train_transform = wrap_in_compose_is_list(train_transform) - self.val_transform = wrap_in_compose_is_list(val_transform) - self.test_transform = wrap_in_compose_is_list(test_transform) + # Transforms can be None (leads to to_tensor default), shared between modalities or individual per modality + if shared_transforms: + # Applying the same transforms with the same parameters to multiple images + shared_transforms = shared_transforms if isinstance(shared_transforms, list) else modalities + assert shared_transforms == modalities, "Non-image modalities not yet supported with shared_transforms" + + if isinstance(train_transform, dict): + self.train_transform = {m: wrap_in_compose_is_list(train_transform[m]) if m in train_transform else None + for m in modalities} + elif shared_transforms: + self.train_transform = wrap_in_compose_is_list(train_transform, additional_targets=shared_transforms) + else: + self.train_transform = {m: wrap_in_compose_is_list(train_transform) + for m in modalities} + + if isinstance(val_transform, dict): + self.val_transform = {m: wrap_in_compose_is_list(val_transform[m]) if m in 
val_transform else None + for m in modalities} + elif shared_transforms: + self.val_transform = wrap_in_compose_is_list(val_transform, additional_targets=shared_transforms) + else: + self.val_transform = {m: wrap_in_compose_is_list(val_transform) + for m in modalities} + + if isinstance(test_transform, dict): + self.test_transform = {m: wrap_in_compose_is_list(test_transform[m]) if m in test_transform else None + for m in modalities} + elif shared_transforms: + self.test_transform = wrap_in_compose_is_list(test_transform, additional_targets=shared_transforms) + else: + self.test_transform = {m: wrap_in_compose_is_list(test_transform) + for m in modalities} - # self.aug = AugmentationSequential( - # K.Normalize(means, stds), - # data_keys=["image"], - # ) - means = {m: load_from_file_or_attribute(kwargs[f'{m}_means']) for m in modalities} - stds = {m: load_from_file_or_attribute(kwargs[f'{m}_stds']) for m in modalities} + means = {m: load_from_file_or_attribute(means[m]) for m in means.keys()} + stds = {m: load_from_file_or_attribute(stds[m]) for m in stds.keys()} self.aug = {m: Normalize(means[m], stds[m]) for m in modalities} self.chunk_data = chunk_data if chunk_data: self.collate_fn = collate_chunk_dicts - # self.collate_fn = collate_fn_list_dicts def setup(self, stage: str) -> None: if stage in ["fit"]: @@ -330,4 +366,5 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, drop_last=split == "train" and self.drop_last, + pin_memory=self.pin_memory, ) diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 8bc25fea..ee7d3ee9 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -21,7 +21,9 @@ from torch import Tensor from torchgeo.datasets import NonGeoDataset -from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files, generate_bands_intervals +from terratorch.datasets.utils import (HLSBands, default_transform, filter_valid_files, generate_bands_intervals, + to_tensor) +from terratorch.datasets.transforms import MultiModalTransforms class GenericMultimodalDataset(NonGeoDataset, ABC): @@ -100,7 +102,7 @@ def __init__( for m, m_root in data_root.items()} else: self.image_files = {m: sorted(glob.glob(m_root)) for m, m_root in data_root.items()} - self.constant_scale = constant_scale or {m: 1. for m in self.modalities} + self.constant_scale = {m: constant_scale[m] or 1. 
for m in self.modalities} self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace if self.label_data_root: @@ -171,30 +173,42 @@ def __init__( if dataset_bands is not None: self.dataset_bands = {m: generate_bands_intervals(m_bands) for m, m_bands in dataset_bands.items()} + else: + self.dataset_bands = None if output_bands is not None: self.output_bands = {m: generate_bands_intervals(m_bands) for m, m_bands in output_bands.items()} - - for modality in self.modalities: - if modality in self.output_bands and modality not in self.dataset_bands: - msg = f"If output bands are provided, dataset_bands must also be provided (modality: {modality})" - raise Exception(msg) # noqa: PLE0101 + for modality in self.modalities: + if modality in self.output_bands and modality not in self.dataset_bands: + msg = f"If output bands are provided, dataset_bands must also be provided (modality: {modality})" + raise Exception(msg) # noqa: PLE0101 + else: + self.output_bands = {} self.filter_indices = {} # There is a special condition if the bands are defined as simple strings. if self.output_bands: for m in self.output_bands.keys(): + if m not in self.output_bands or self.output_bands[m] == self.dataset_bands[m]: + continue if len(set(self.output_bands[m]) & set(self.dataset_bands[m])) != len(self.output_bands[m]): msg = f"Output bands must be a subset of dataset bands (Modality: {m})" raise Exception(msg) - if self.output_bands[m] == self.dataset_bands[m]: - continue self.filter_indices[m] = [self.dataset_bands[m].index(band) for band in self.output_bands[m]] - # TODO: Implement multi-modal transforms + # If no transform is given, apply only to transform to torch tensor - self.transform = transform if transform else default_transform + if isinstance(transform, A.Compose): + self.transform = MultiModalTransforms(transform) + elif transform is None: + self.transform = to_tensor + else: + # TODO: Test transforms per modality + # Modality-specific transforms + transform = {m: transform[m] if m in transform else default_transform + for m in self.modalities} + self.transform = MultiModalTransforms(transform, shared=False) import warnings import rasterio @@ -233,7 +247,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.reduce_zero_label: output["mask"] -= 1 if self.transform: - output = self.transform(**output) + output = self.transform(output) output["filename"] = self.samples[index] return output diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index c3b608b8..3e089193 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -3,6 +3,7 @@ from albumentations import BasicTransform, Compose, ImageOnlyTransform from einops import rearrange from torch import Tensor +import albumentations as A N_DIMS_FOR_TEMPORAL = 4 N_DIMS_FLATTENED_TEMPORAL = 3 @@ -37,7 +38,7 @@ def get_transform_init_args_names(self): class UnflattenTemporalFromChannels(ImageOnlyTransform): - """Flatten the temporal dimension into channels + """Unflatten the temporal dimension from the channel dimension Assumes channels first (usually should be run after ToTensorV2)""" def __init__(self, n_timesteps: int | None = None, n_channels: int | None = None): @@ -61,6 +62,70 @@ def apply(self, img, **params): def get_transform_init_args_names(self): return ("n_timesteps", "n_channels") +class FlattenSamplesIntoChannels(ImageOnlyTransform): + """Flatten the sample and optional temporal dimension into channels""" + + def __init__(self, time_dim: bool = True): + 
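        # Shape example (values illustrative): an input stacked as
        # (samples, time, height, width, channels), e.g. (2, 3, 224, 224, 6),
        # is flattened by apply() below to (224, 224, 2 * 3 * 6) = (224, 224, 36),
        # so that albumentations sees a single channels-last image.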
super().__init__(True, 1) + self.time_dim = time_dim + + def apply(self, img, **params): + if self.time_dim: + rearranged = rearrange(img, + "samples time height width channels -> height width (samples time channels)") + else: + rearranged = rearrange(img, "samples height width channels -> height width (samples channels)") + return rearranged + + def get_transform_init_args_names(self): + return () + + +class UnflattenSamplesFromChannels(ImageOnlyTransform): + """Unflatten the sample and optional the temporal dimension from the channel dimension + Assumes channels first (usually should be run after ToTensorV2)""" + + def __init__( + self, + time_dim: bool = True, + n_samples: int | None = None, + n_timesteps: int | None = None, + n_channels: int | None = None + ): + super().__init__(True, 1) + + self.time_dim = time_dim + if self.time_dim: + if bool(n_channels) + bool(n_timesteps) + bool(n_samples) < 2: + msg = "Two of n_channels, n_timesteps, and n_channels must be provided" + raise Exception(msg) + if n_timesteps and n_channels: + self.additional_info = {"channels": n_channels, "time": n_timesteps} + elif n_timesteps and n_samples: + self.additional_info = {"time": n_timesteps, "samples": n_samples} + else: + self.additional_info = {"channels": n_channels, "samples": n_samples} + else: + if n_channels is None and n_samples is None: + msg = "One of n_channels or n_samples must be provided" + raise Exception(msg) + self.additional_info = {"channels": n_channels} if n_channels else {"samples": n_samples} + + def apply(self, img, **params): + if self.time_dim: + rearranged = rearrange( + img, "(samples time channels) height width -> samples time channels height width", + **self.additional_info + ) + else: + rearranged = rearrange( + img, "(samples channels) height width -> samples channels height width", **self.additional_info + ) + return rearranged + + def get_transform_init_args_names(self): + return ("n_timesteps", "n_channels") + class Rearrange(ImageOnlyTransform): """Flatten the temporal dimension into channels""" @@ -91,3 +156,24 @@ def apply(self, img, **params): def get_transform_init_args_names(self): return "band_indices" + + +class MultiModalTransforms: + """Applies albumentations transforms to multiple images""" + def __init__(self, transforms: dict | A.Compose, shared : bool = True): + self.transforms = transforms + self.shared = shared + + def __call__(self, data: dict): + if self.shared: + # albumentations requires a key 'image' + image_modality = list(data.keys())[0] + data['image'] = data.pop(image_modality) + data = self.transforms(**data) + data[image_modality] = data.pop('image') + else: + # Applies transformations for each modality separate + for key, value in data.items(): + data[key] = self.transforms[key](image=value)['image'] # Only works with image modalities + + return data From daba2cec941d214cdba393b63f07d98b64705934 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 18 Oct 2024 18:27:59 +0200 Subject: [PATCH 12/42] Add sample_num_modalities Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 57 +++++++++++++++++-- .../datasets/generic_multimodal_dataset.py | 11 +++- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 6fe408bf..5e4c6668 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -6,14 +6,14 @@ import 
os from collections.abc import Callable, Iterable from pathlib import Path -from typing import Any +from typing import Any, Iterator import albumentations as A import kornia.augmentation as K import numpy as np import torch from torch import Tensor -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler from torchgeo.datamodules import NonGeoDataModule from torchgeo.transforms import AugmentationSequential @@ -61,6 +61,42 @@ def __call__(self, batch): return batch +class MultiModalBatchSampler(BatchSampler): + """ + Sample a defined number of modalities per batch (see sample_num_modalities and sample_replace) + """ + def __init__(self, modalities, sample_num_modalities, sample_replace, *args, **kwargs): + super().__init__(*args, **kwargs) + self.modalities = modalities + self.sample_num_modalities = sample_num_modalities + self.sample_replace = sample_replace + + def __iter__(self) -> Iterator[list[int]]: + # Select sampled modalities per batch + sampled_modalities = np.random.choice(self.modalities, self.sample_num_modalities, replace=self.sample_replace) + + if self.drop_last: + sampler_iter = iter(self.sampler) + while True: + try: + batch = [(next(sampler_iter), sampled_modalities) for _ in range(self.batch_size)] + yield batch + except StopIteration: + break + else: + batch = [0] * self.batch_size + idx_in_batch = 0 + for idx in self.sampler: + batch[idx_in_batch] = (idx, sampled_modalities) + idx_in_batch += 1 + if idx_in_batch == self.batch_size: + yield batch + idx_in_batch = 0 + batch = [0] * self.batch_size + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + class GenericMultiModalDataModule(NonGeoDataModule): """ This is a generic datamodule class for instantiating data modules at runtime. 
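
As a rough standalone sketch of what the modality-sampling batch sampler above produces (modality names and sizes are made up):

    import numpy as np
    from torch.utils.data import SequentialSampler

    modalities = ["S2L2A", "S1GRD", "DEM"]
    sampler = SequentialSampler(range(4))  # stands in for dataset indices 0..3
    # one modality subset is drawn per iterator pass and reused for its batches
    picked = np.random.choice(modalities, 2, replace=False)

    batch, batch_size = [], 2
    for idx in sampler:
        batch.append((idx, picked))  # each dataset __getitem__ call receives (index, modalities)
        if len(batch) == batch_size:
            print(batch)
            batch = []
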
@@ -88,7 +124,7 @@ def __init__( val_split: Path | None = None, test_split: Path | None = None, ignore_split_file_extensions: bool = True, - allow_substring_split_file: bool = True, # TODO: Check if covered + allow_substring_split_file: bool = True, # TODO: Check if covered dataset_bands: dict | None = None, output_bands: dict | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, @@ -107,6 +143,8 @@ def __init__( drop_last: bool = True, pin_memory: bool = False, chunk_data: bool = False, + sample_num_modalities: int | None = None, + sample_replace: bool = False, **kwargs: Any, ) -> None: """Constructor @@ -211,6 +249,8 @@ def __init__( self.no_label_replace = no_label_replace self.drop_last = drop_last self.pin_memory = pin_memory + self.sample_num_modalities = sample_num_modalities + self.sample_replace = sample_replace self.dataset_bands = dataset_bands self.output_bands = output_bands @@ -359,12 +399,17 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: """ dataset = self._valid_attribute(f"{split}_dataset", "dataset") batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + batch_sampler = MultiModalBatchSampler( + self.modalities, self.sample_num_modalities, self.sample_replace, + RandomSampler(dataset) if split == "train" else SequentialSampler(dataset), + batch_size=batch_size, drop_last=split == "train" and self.drop_last) + return DataLoader( dataset=dataset, - batch_size=batch_size, - shuffle=split == "train", + batch_sampler=batch_sampler, + # shuffle=split == "train", num_workers=self.num_workers, collate_fn=self.collate_fn, - drop_last=split == "train" and self.drop_last, + # drop_last=split == "train" and self.drop_last, pin_memory=self.pin_memory, ) diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index ee7d3ee9..674781d7 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -204,7 +204,6 @@ def __init__( elif transform is None: self.transform = to_tensor else: - # TODO: Test transforms per modality # Modality-specific transforms transform = {m: transform[m] if m in transform else default_transform for m in self.modalities} @@ -219,7 +218,15 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: output = {} - for modality, file in self.samples[index].items(): + if isinstance(index, tuple): + # Load only sampled modalities instead of all modalities + # (see sample_num_modalities in GenericMultiModalDataModule for details) + index, modalities = index + sample = {m: self.samples[index][m] for m in modalities} + else: + sample = self.samples[index] + + for modality, file in sample.items(): data = self._load_file( file, nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace).to_numpy() From f839379ecb9ae2493a939d18046816f0894a1645 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 18 Oct 2024 19:10:26 +0200 Subject: [PATCH 13/42] Multimodal dataset with multiple dataset folders Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 36 ++++++++++++++---- .../datasets/generic_multimodal_dataset.py | 38 ++++++++++++------- 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 5e4c6668..4c3359f6 100644 --- 
a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -4,6 +4,7 @@ This module contains generic data modules for instantiation at runtime. """ import os +import logging from collections.abc import Callable, Iterable from pathlib import Path from typing import Any, Iterator @@ -72,6 +73,9 @@ def __init__(self, modalities, sample_num_modalities, sample_replace, *args, **k self.sample_replace = sample_replace def __iter__(self) -> Iterator[list[int]]: + """ + Code similar to BatchSampler but samples tuples in the format (idx, ['m1', 'm2', ...]) + """ # Select sampled modalities per batch sampled_modalities = np.random.choice(self.modalities, self.sample_num_modalities, replace=self.sample_replace) @@ -115,6 +119,7 @@ def __init__( stds: dict, task: str | None = None, num_classes: int | None = None, + img_grep: str | dict | None = None, label_grep: str | None = None, train_label_data_root: Path | None = None, val_label_data_root: Path | None = None, @@ -223,8 +228,11 @@ def __init__( super().__init__(dataset_class, batch_size, num_workers, **kwargs) self.num_classes = num_classes self.modalities = modalities - self.img_grep = {m: kwargs.get('{m}_grep', '*') for m in modalities} - self.label_grep = label_grep + if isinstance(img_grep, dict): + self.img_grep = {m: img_grep[m] if m in img_grep else '*' for m in modalities} + else: + self.img_grep = {m: img_grep or '*' for m in modalities} + self.label_grep = label_grep or '*' self.train_root = train_data_root self.val_root = val_data_root self.test_root = test_data_root @@ -326,6 +334,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, ) + logging.info(f'Train dataset: {len(self.train_dataset)}') if stage in ["fit", "validate"]: self.val_dataset = self.dataset_class( data_root=self.val_root, @@ -347,6 +356,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, ) + logging.info(f'Val dataset: {len(self.val_dataset)}') if stage in ["test"]: self.test_dataset = self.dataset_class( data_root=self.test_root, @@ -368,6 +378,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, ) + logging.info(f'Test dataset: {len(self.test_dataset)}') if stage in ["predict"] and self.predict_root: self.predict_dataset = self.dataset_class( data_root=self.predict_root, @@ -383,6 +394,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, ) + logging.info(f'Predict dataset: {len(self.predict_dataset)}') def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders. 
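
For reference, a small standalone illustration of why the explicit batch sampler below replaces the shuffle and drop_last arguments (plain PyTorch; any map-style dataset would do here):

    import torch
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

    ds = TensorDataset(torch.arange(10).float())

    # These two loaders draw equivalent batches: a BatchSampler bundles shuffling,
    # batch size and drop_last into one object, which becomes necessary once a
    # custom sampler (such as a modality-sampling one) is used.
    loader_a = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
    loader_b = DataLoader(ds, batch_sampler=BatchSampler(RandomSampler(ds), batch_size=4, drop_last=True))
    print(len(list(loader_a)), len(list(loader_b)))  # 2 2
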
@@ -399,17 +411,25 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: """ dataset = self._valid_attribute(f"{split}_dataset", "dataset") batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") - batch_sampler = MultiModalBatchSampler( - self.modalities, self.sample_num_modalities, self.sample_replace, - RandomSampler(dataset) if split == "train" else SequentialSampler(dataset), - batch_size=batch_size, drop_last=split == "train" and self.drop_last) + if self.sample_num_modalities: + # Custom batch sampler for sampling modalities per batch + batch_sampler = MultiModalBatchSampler( + self.modalities, self.sample_num_modalities, self.sample_replace, + RandomSampler(dataset) if split == "train" else SequentialSampler(dataset), + batch_size=batch_size, + drop_last=split == "train" and self.drop_last + ) + else: + batch_sampler = BatchSampler( + RandomSampler(dataset) if split == "train" else SequentialSampler(dataset), + batch_size=batch_size, + drop_last=split == "train" and self.drop_last + ) return DataLoader( dataset=dataset, batch_sampler=batch_sampler, - # shuffle=split == "train", num_workers=self.num_workers, collate_fn=self.collate_fn, - # drop_last=split == "train" and self.drop_last, pin_memory=self.pin_memory, ) diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 674781d7..23bef394 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -26,6 +26,21 @@ from terratorch.datasets.transforms import MultiModalTransforms +def load_files(root, grep): + if isinstance(root, dict): + files = {} + for m, m_root in root.items(): + if isinstance(m_root, list): + # Iterate over a list of data folders + dir_lists = [glob.glob(os.path.join(r, grep[m])) for r in m_root] + files[m] = sorted([p for l in dir_lists for p in l]) # Concatenate + else: + files[m] = sorted(glob.glob(os.path.join(m_root, grep[m]))) + return files + elif isinstance(root, str): + return sorted(glob.glob(os.path.join(root, grep))) + + class GenericMultimodalDataset(NonGeoDataset, ABC): """ This is a generic dataset class to be used for instantiating datasets from arguments. @@ -34,8 +49,8 @@ class GenericMultimodalDataset(NonGeoDataset, ABC): def __init__( self, - data_root: Path, - label_data_root: Path | None = None, + data_root: Path | list[Path], + label_data_root: Path | list[Path] | None = None, image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, @@ -95,20 +110,17 @@ def __init__( self.modalities = list(data_root.keys()) assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." - self.label_data_root = label_data_root - # Get files per modality - if image_grep: - self.image_files = {m: sorted(glob.glob(os.path.join(m_root, image_grep[m]))) - for m, m_root in data_root.items()} + self.image_files = load_files(data_root, image_grep) + + if label_data_root: + self.segmentation_mask_files = load_files(label_data_root, label_grep) else: - self.image_files = {m: sorted(glob.glob(m_root)) for m, m_root in data_root.items()} + # No labels + self.segmentation_mask_files = None + self.constant_scale = {m: constant_scale[m] or 1. 
for m in self.modalities} self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace - if self.label_data_root: - self.segmentation_mask_files = sorted(glob.glob(os.path.join(label_data_root, label_grep))) - else: - self.segmentation_mask_files = None self.reduce_zero_label = reduce_zero_label self.expand_temporal_dimension = expand_temporal_dimension @@ -146,7 +158,7 @@ def __init__( else: valid_files = [os.path.splitext(os.path.basename(file))[0] for file in self.image_files[self.modalities[0]]] - logging.info(f'Found {len(valid_files)} file names.') + # Get multi-modal samples in form: {'modality': path, ..., 'mask': path} self.samples = [] From 157dcf04d9a111283bfe2a7ce039632ecd1e64b6 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 25 Oct 2024 16:07:17 +0200 Subject: [PATCH 14/42] Update multimodal dataest Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 17 --- .../datasets/generic_multimodal_dataset.py | 134 ++++++------------ 2 files changed, 43 insertions(+), 108 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 4c3359f6..c0ae3f8c 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -128,8 +128,6 @@ def __init__( train_split: Path | None = None, val_split: Path | None = None, test_split: Path | None = None, - ignore_split_file_extensions: bool = True, - allow_substring_split_file: bool = True, # TODO: Check if covered dataset_bands: dict | None = None, output_bands: dict | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, @@ -173,13 +171,6 @@ def __init__( train_split (Path | None, optional): _description_. Defaults to None. val_split (Path | None, optional): _description_. Defaults to None. test_split (Path | None, optional): _description_. Defaults to None. - ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split - file to determine which files to include in the dataset. - E.g. necessary for Eurosat, since the split files specify ".jpg" but files are - actually ".jpg". Defaults to True. - allow_substring_split_file (bool, optional): Whether the split files contain substrings - that must be present in file names to be included (as in mmsegmentation), or exact - matches (e.g. eurosat). Defaults to True. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None. output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. Naming must match that of dataset_bands. Defaults to None. 
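
To make the band-subsetting convention described above concrete, a small standalone sketch of how output bands map to channel indices (band names are illustrative):

    dataset_bands = {"S2L2A": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]}
    output_bands = {"S2L2A": ["RED", "GREEN", "BLUE"]}

    # output_bands must be a subset of dataset_bands; only the matching channel
    # indices are kept, in the requested order.
    filter_indices = {
        m: [dataset_bands[m].index(b) for b in bands] for m, bands in output_bands.items()
    }
    assert filter_indices == {"S2L2A": [2, 1, 0]}
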
@@ -243,8 +234,6 @@ def __init__( self.train_split = train_split self.val_split = val_split self.test_split = test_split - self.ignore_split_file_extensions = ignore_split_file_extensions - self.allow_substring_split_file = allow_substring_split_file self.constant_scale = constant_scale if isinstance(self.constant_scale, dict): # Fill in missing modalities @@ -321,8 +310,6 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.train_label_data_root, split=self.train_split, - ignore_split_file_extensions=self.ignore_split_file_extensions, - allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -343,8 +330,6 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.val_label_data_root, split=self.val_split, - ignore_split_file_extensions=self.ignore_split_file_extensions, - allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -365,8 +350,6 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.test_label_data_root, split=self.test_split, - ignore_split_file_extensions=self.ignore_split_file_extensions, - allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 23bef394..394a535b 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -26,21 +26,6 @@ from terratorch.datasets.transforms import MultiModalTransforms -def load_files(root, grep): - if isinstance(root, dict): - files = {} - for m, m_root in root.items(): - if isinstance(m_root, list): - # Iterate over a list of data folders - dir_lists = [glob.glob(os.path.join(r, grep[m])) for r in m_root] - files[m] = sorted([p for l in dir_lists for p in l]) # Concatenate - else: - files[m] = sorted(glob.glob(os.path.join(m_root, grep[m]))) - return files - elif isinstance(root, str): - return sorted(glob.glob(os.path.join(root, grep))) - - class GenericMultimodalDataset(NonGeoDataset, ABC): """ This is a generic dataset class to be used for instantiating datasets from arguments. @@ -49,13 +34,11 @@ class GenericMultimodalDataset(NonGeoDataset, ABC): def __init__( self, - data_root: Path | list[Path], + data_root: dict[Path], label_data_root: Path | list[Path] | None = None, image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, - ignore_split_file_extensions: bool = True, - allow_substring_split_file: bool = True, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet (collate_fn required). @@ -82,13 +65,6 @@ def __init__( split (Path, optional): Path to file containing files to be used for this split. The file should be a new-line separated prefixes contained in the desired files. Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) - ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split - file to determine which files to include in the dataset. - E.g. 
necessary for Eurosat, since the split files specify ".jpg" but files are - actually ".jpg". Defaults to True. - allow_substring_split_file (bool, optional): Whether the split files contain substrings - that must be present in file names to be included (as in mmsegmentation), or exact - matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None. output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands. @@ -110,13 +86,13 @@ def __init__( self.modalities = list(data_root.keys()) assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." - self.image_files = load_files(data_root, image_grep) - if label_data_root: - self.segmentation_mask_files = load_files(label_data_root, label_grep) - else: - # No labels - self.segmentation_mask_files = None + # Convert path strings to lists + for m, m_dir in data_root.items(): + if not isinstance(m_dir, list): + data_root[m] = [m_dir] + if label_data_root and not isinstance(label_data_root, list): + label_data_root = [label_data_root] self.constant_scale = {m: constant_scale[m] or 1. for m in self.modalities} self.no_data_replace = no_data_replace @@ -128,55 +104,53 @@ def __init__( msg = "Please provide dataset_bands for each modality when expand_temporal_dimension is True" raise Exception(msg) + # Load samples based on split file if self.split_file is not None: with open(self.split_file) as f: split = f.readlines() valid_files = {rf"{substring.strip()}" for substring in split} - if not ignore_split_file_extensions or not allow_substring_split_file: - # TODO: Only need for special cases, can we generalize the multi-modal samples and remove this part? - for m, m_files in self.image_files.items(): - self.image_files[m] = filter_valid_files( - m_files, - valid_files=valid_files, - ignore_extensions=ignore_split_file_extensions, - allow_substring=allow_substring_split_file, - ) - if self.segmentation_mask_files: - self.segmentation_mask_files = filter_valid_files( - self.segmentation_mask_files, - valid_files=valid_files, - ignore_extensions=ignore_split_file_extensions, - allow_substring=allow_substring_split_file, - ) + else: - logging.warning('No split file provided. 
' - 'This requires that all modalities have the same filename for aligned samples.') + image_files = {} + for m, m_dirs in data_root.items(): + dir_lists = [glob.glob(os.path.join(r, image_grep[m])) for r in m_dirs] + image_files[m] = sorted([p for l in dir_lists for p in l]) # Concatenate + + if label_data_root: + dir_lists = [glob.glob(os.path.join(r, label_grep)) for r in label_data_root] + segmentation_mask_files = sorted([p for l in dir_lists for p in l]) # Concatenate + if allow_missing_modalities: - all_files = [os.path.splitext(os.path.basename(file))[0] - for file in np.concatenate(list(self.image_files.values()))] - valid_files = list(set(all_files)) + valid_files = set([os.path.splitext(os.path.basename(file))[0] + for file in np.concatenate(list(image_files.values()))]) else: valid_files = [os.path.splitext(os.path.basename(file))[0] - for file in self.image_files[self.modalities[0]]] + for file in image_files[self.modalities[0]]] - - # Get multi-modal samples in form: {'modality': path, ..., 'mask': path} self.samples = [] - num_modalities = len(self.modalities) + int(self.segmentation_mask_files is not None) - for split_file in valid_files: + num_modalities = len(self.modalities) + int(label_data_root is not None) + # Iterate over all files in split + for file in valid_files: sample = {} - for m, files in self.image_files.items(): - matching_files = [file for file in files if split_file in file] - if matching_files: - sample[m] = matching_files[0] - if self.segmentation_mask_files: - matching_files = [file for file in self.segmentation_mask_files if split_file in file] - if matching_files: - sample['mask'] = matching_files[0] - else: - # Skip samples with missing labels - continue - if allow_missing_modalities or len(sample) == num_modalities: + # Iterate over all modalities + for m, m_dirs in data_root.items(): + # Iterate over all directories of the current modality + for m_dir in m_dirs: + m_files = glob.glob(os.path.join(m_dir, file + image_grep[m])) + if m_files: + sample[m] = m_files[0] + break + if label_data_root: + for l_dir in label_data_root: + l_files = glob.glob(os.path.join(l_dir, file + label_grep)) + if l_files: + sample['mask'] = l_files[0] + break + if 'mask' not in sample: + # Only add sample if mask is present + break + + if len(sample) == num_modalities or allow_missing_modalities: self.samples.append(sample) self.rgb_modality = rgb_modality or self.modalities[0] @@ -297,8 +271,6 @@ def __init__( image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, - ignore_split_file_extensions: bool = True, - allow_substring_split_file: bool = True, rgb_modality: str | None = None, rgb_indices: list[str] | None = None, allow_missing_modalities: bool = False, @@ -327,13 +299,6 @@ def __init__( split (Path, optional): Path to file containing files to be used for this split. The file should be a new-line separated prefixes contained in the desired files. Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) - ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split - file to determine which files to include in the dataset. - E.g. necessary for Eurosat, since the split files specify ".jpg" but files are - actually ".jpg". Defaults to True - allow_substring_split_file (bool, optional): Whether the split files contain substrings - that must be present in file names to be included (as in mmsegmentation), or exact - matches (e.g. eurosat). Defaults to True. 
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. @@ -356,8 +321,6 @@ def __init__( image_grep=image_grep, label_grep=label_grep, split=split, - ignore_split_file_extensions=ignore_split_file_extensions, - allow_substring_split_file=allow_substring_split_file, rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, @@ -472,8 +435,6 @@ def __init__( image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, - ignore_split_file_extensions: bool = True, - allow_substring_split_file: bool = True, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities : bool = False, @@ -500,13 +461,6 @@ def __init__( split (Path, optional): Path to file containing files to be used for this split. The file should be a new-line separated prefixes contained in the desired files. Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) - ignore_split_file_extensions (bool, optional): Whether to disregard extensions when using the split - file to determine which files to include in the dataset. - E.g. necessary for Eurosat, since the split files specify ".jpg" but files are - actually ".jpg". Defaults to True. - allow_substring_split_file (bool, optional): Whether the split files contain substrings - that must be present in file names to be included (as in mmsegmentation), or exact - matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. 
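
A hypothetical on-disk layout that the multi-modal sample matching described above would resolve (all folder and file names invented):

    # Layout:
    #   train_s2/   chip_0001.tif, chip_0002.tif
    #   train_s1/   chip_0001.tif, chip_0002.tif
    #   train_lbl/  chip_0001.tif, chip_0002.tif
    #   split.txt   contains the prefixes chip_0001 and chip_0002
    # With data_root={"S2": "train_s2", "S1": "train_s1"}, label_data_root="train_lbl"
    # and a grep of "*.tif", the matching step would produce:
    samples = [
        {"S2": "train_s2/chip_0001.tif", "S1": "train_s1/chip_0001.tif", "mask": "train_lbl/chip_0001.tif"},
        {"S2": "train_s2/chip_0002.tif", "S1": "train_s1/chip_0002.tif", "mask": "train_lbl/chip_0002.tif"},
    ]
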
@@ -528,8 +482,6 @@ def __init__( image_grep=image_grep, label_grep=label_grep, split=split, - ignore_split_file_extensions=ignore_split_file_extensions, - allow_substring_split_file=allow_substring_split_file, rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, From a8053f098ba53a50f2aa4ca430c8529279d8ec32 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Wed, 30 Oct 2024 10:12:45 +0100 Subject: [PATCH 15/42] Added allow_substring_split_file to multimodal dataset Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 6 +++ .../datasets/generic_multimodal_dataset.py | 48 ++++++++++++++----- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index c0ae3f8c..d872b6be 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -134,6 +134,7 @@ def __init__( predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, + allow_substring_split_file: bool = False, constant_scale: dict | float = 1., train_transform: dict | A.Compose | None | list[A.BasicTransform] = None, val_transform: dict | A.Compose | None | list[A.BasicTransform] = None, @@ -234,6 +235,7 @@ def __init__( self.train_split = train_split self.val_split = val_split self.test_split = test_split + self.allow_substring_split_file = allow_substring_split_file self.constant_scale = constant_scale if isinstance(self.constant_scale, dict): # Fill in missing modalities @@ -310,6 +312,7 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.train_label_data_root, split=self.train_split, + allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -330,6 +333,7 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.val_label_data_root, split=self.val_split, + allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -350,6 +354,7 @@ def setup(self, stage: str) -> None: label_grep=self.label_grep, label_data_root=self.test_label_data_root, split=self.test_split, + allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -366,6 +371,7 @@ def setup(self, stage: str) -> None: self.predict_dataset = self.dataset_class( data_root=self.predict_root, num_classes=self.num_classes, + allow_substring_split_file=self.allow_substring_split_file, dataset_bands=self.predict_dataset_bands, output_bands=self.predict_output_bands, constant_scale=self.constant_scale, diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 394a535b..24e58950 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -42,6 +42,7 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet (collate_fn required). 
+ allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: dict[float] = None, @@ -118,14 +119,19 @@ def __init__( if label_data_root: dir_lists = [glob.glob(os.path.join(r, label_grep)) for r in label_data_root] - segmentation_mask_files = sorted([p for l in dir_lists for p in l]) # Concatenate + image_files['mask'] = sorted([p for l in dir_lists for p in l]) # Concatenate + + if allow_substring_split_file: + # Get exact match of filenames + get_file_id = lambda s: os.path.basename(s) + else: + # Remove file extensions + get_file_id = lambda s: os.path.splitext(os.path.basename(s))[0] if allow_missing_modalities: - valid_files = set([os.path.splitext(os.path.basename(file))[0] - for file in np.concatenate(list(image_files.values()))]) + valid_files = set([get_file_id(file) for file in np.concatenate(list(image_files.values()))]) else: - valid_files = [os.path.splitext(os.path.basename(file))[0] - for file in image_files[self.modalities[0]]] + valid_files = [get_file_id(file) for file in image_files[self.modalities[0]]] self.samples = [] num_modalities = len(self.modalities) + int(label_data_root is not None) @@ -136,16 +142,32 @@ def __init__( for m, m_dirs in data_root.items(): # Iterate over all directories of the current modality for m_dir in m_dirs: - m_files = glob.glob(os.path.join(m_dir, file + image_grep[m])) - if m_files: - sample[m] = m_files[0] - break + if allow_substring_split_file: + # Substring match with image_grep + m_files = glob.glob(os.path.join(m_dir, file + image_grep[m])) + if m_files: + sample[m] = m_files[0] + break + else: + # Exact match + file_path = os.path.join(m_dir, file) + if os.path.isfile(file_path): + sample[m] = file_path + break if label_data_root: for l_dir in label_data_root: - l_files = glob.glob(os.path.join(l_dir, file + label_grep)) - if l_files: - sample['mask'] = l_files[0] - break + if allow_substring_split_file: + # Substring match with label_grep + l_files = glob.glob(os.path.join(l_dir, file + label_grep)) + if l_files: + sample['mask'] = l_files[0] + break + else: + # Exact match + file_path = os.path.join(l_dir, file) + if os.path.isfile(file_path): + sample['mask'] = file_path + break if 'mask' not in sample: # Only add sample if mask is present break From 3565259381ead112f6e844f35ba5de10fe119e1a Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 31 Oct 2024 18:44:00 +0100 Subject: [PATCH 16/42] Fix multimodal datalaoding Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 32 ++++++----- .../datasets/generic_multimodal_dataset.py | 54 +++++++++++++------ terratorch/models/pixel_wise_model.py | 6 ++- 3 files changed, 63 insertions(+), 29 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index d872b6be..b2752d16 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -41,24 +41,30 @@ def wrap_in_compose_is_list(transform_list, additional_targets=None): if isinstance(transform_list, Iterable) else transform_list -class Normalize(Callable): +class MultimodalNormalize(Callable): def __init__(self, means, stds): super().__init__() self.means = means self.stds = stds def __call__(self, batch): - image = batch["image"] - if len(image.shape) == 5: - means = torch.tensor(self.means, 
device=image.device).view(1, -1, 1, 1, 1) - stds = torch.tensor(self.stds, device=image.device).view(1, -1, 1, 1, 1) - elif len(image.shape) == 4: - means = torch.tensor(self.means, device=image.device).view(1, -1, 1, 1) - stds = torch.tensor(self.stds, device=image.device).view(1, -1, 1, 1) - else: - msg = f"Expected batch to have 5 or 4 dimensions, but got {len(image.shape)}" - raise Exception(msg) - batch["image"] = (image - means) / stds + for m in self.means.keys(): + if m not in batch: + continue + image = batch[m] + if len(image.shape) == 5: + means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1, 1) + elif len(image.shape) == 4: + means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1) + elif len(self.means[m]) == 1: + means = torch.tensor(self.means[m], device=image.device) + stds = torch.tensor(self.stds[m], device=image.device) + else: + msg = f"Expected batch to have 5 or 4 dimensions, but got {len(image.shape)}" + raise Exception(msg) + batch[m] = (image - means) / stds return batch @@ -297,7 +303,7 @@ def __init__( means = {m: load_from_file_or_attribute(means[m]) for m in means.keys()} stds = {m: load_from_file_or_attribute(stds[m]) for m in stds.keys()} - self.aug = {m: Normalize(means[m], stds[m]) for m in modalities} + self.aug = MultimodalNormalize(means, stds) self.chunk_data = chunk_data if chunk_data: diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 24e58950..e059e876 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -5,6 +5,7 @@ import glob import logging import os +import torch from abc import ABC from pathlib import Path from typing import Any @@ -18,7 +19,6 @@ from matplotlib import pyplot as plt from matplotlib.figure import Figure from matplotlib.patches import Rectangle -from torch import Tensor from torchgeo.datasets import NonGeoDataset from terratorch.datasets.utils import (HLSBands, default_transform, filter_valid_files, generate_bands_intervals, @@ -26,6 +26,21 @@ from terratorch.datasets.transforms import MultiModalTransforms +class MultimodalToTensor(): + def __init__(self, modalities): + self.modalities = modalities + def __call__(self, d): + new_dict = {} + for k, v in d.items(): + if not isinstance(v, np.ndarray): + new_dict[k] = v + else: + if k in self.modalities and len(v.shape) >= 3: # Assuming raster modalities with 3+ dimensions + v = np.moveaxis(v, -1, 0) + new_dict[k] = torch.from_numpy(v) + return new_dict + + class GenericMultimodalDataset(NonGeoDataset, ABC): """ This is a generic dataset class to be used for instantiating datasets from arguments. 
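
For intuition, a standalone sketch of the per-modality normalization applied by the data module (channel counts and statistics are invented):

    import torch

    means = {"S2L2A": [0.10, 0.20, 0.30], "S1GRD": [-12.5, -20.1]}
    stds = {"S2L2A": [0.50, 0.50, 0.50], "S1GRD": [5.3, 5.9]}
    batch = {"S2L2A": torch.rand(4, 3, 64, 64), "S1GRD": torch.rand(4, 2, 64, 64)}

    for m, img in batch.items():
        # view(1, -1, 1, 1) broadcasts the per-channel statistics over batch and spatial dims
        mean = torch.tensor(means[m]).view(1, -1, 1, 1)
        std = torch.tensor(stds[m]).view(1, -1, 1, 1)
        batch[m] = (img - mean) / std
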
@@ -205,12 +220,12 @@ def __init__( self.filter_indices[m] = [self.dataset_bands[m].index(band) for band in self.output_bands[m]] - # If no transform is given, apply only to transform to torch tensor if isinstance(transform, A.Compose): self.transform = MultiModalTransforms(transform) elif transform is None: - self.transform = to_tensor + self.transform = MultimodalToTensor(self.modalities) + logging.warning(f'Default transforms ') else: # Modality-specific transforms transform = {m: transform[m] if m in transform else default_transform @@ -246,9 +261,8 @@ def __getitem__(self, index: int) -> dict[str, Any]: if modality == 'mask': data = data[0] - # TODO: Assumes all modalities with three dimension and more to be channel-first images - if len(data.shape) >= 3: - # to channels last + if len(data.shape) >= 3: # TODO: Assumes raster modalities by 3+ dimensions. + # to channels last (required by albumentations) data = np.moveaxis(data, -3, -1) if modality in self.filter_indices: @@ -263,7 +277,13 @@ def __getitem__(self, index: int) -> dict[str, Any]: output["mask"] -= 1 if self.transform: output = self.transform(output) - output["filename"] = self.samples[index] + + # Tasks expect data to be stored in 'image', moving modalities to image dict + output = { + 'image': {m: output[m] for m in self.modalities if m in output}, + 'mask': output['mask'] if 'mask' in output else None, + 'filename': self.samples[index] + } return output @@ -296,6 +316,7 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[str] | None = None, allow_missing_modalities: bool = False, + allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, class_names: list[str] | None = None, @@ -346,6 +367,7 @@ def __init__( rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, + allow_substring_split_file=allow_substring_split_file, dataset_bands=dataset_bands, output_bands=output_bands, constant_scale=constant_scale, @@ -363,7 +385,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: item["mask"] = item["mask"].long() return item - def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. Args: @@ -381,7 +403,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure image = sample[self.rgb_modality] if len(image.shape) == 5: # TODO: Needed? Copied from generic dataest. 
return - if isinstance(image, Tensor): + if isinstance(image, torch.Tensor): image = image.numpy() image = image.take(self.rgb_indices, axis=0) image = np.transpose(image, (1, 2, 0)) @@ -389,13 +411,13 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure image = np.clip(image, 0, 1) label_mask = sample["mask"] - if isinstance(label_mask, Tensor): + if isinstance(label_mask, torch.Tensor): label_mask = label_mask.numpy() showing_predictions = "prediction" in sample if showing_predictions: prediction_mask = sample["prediction"] - if isinstance(prediction_mask, Tensor): + if isinstance(prediction_mask, torch.Tensor): prediction_mask = prediction_mask.numpy() return self._plot_sample( @@ -460,6 +482,7 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities : bool = False, + allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, @@ -507,6 +530,7 @@ def __init__( rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, + allow_substring_split_file=allow_substring_split_file, dataset_bands=dataset_bands, output_bands=output_bands, constant_scale=constant_scale, @@ -522,7 +546,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: item["mask"] = item["mask"].float() return item - def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> Figure: """Plot a sample from the dataset. Args: @@ -540,7 +564,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure image = sample["image"] if len(image.shape) == 5: return - if isinstance(image, Tensor): + if isinstance(image, torch.Tensor): image = image.numpy() image = image.take(self.rgb_indices, axis=0) image = np.transpose(image, (1, 2, 0)) @@ -548,13 +572,13 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure image = np.clip(image, 0, 1) label_mask = sample["mask"] - if isinstance(label_mask, Tensor): + if isinstance(label_mask, torch.Tensor): label_mask = label_mask.numpy() showing_predictions = "prediction" in sample if showing_predictions: prediction_mask = sample["prediction"] - if isinstance(prediction_mask, Tensor): + if isinstance(prediction_mask, torch.Tensor): prediction_mask = prediction_mask.numpy() return self._plot_sample( diff --git a/terratorch/models/pixel_wise_model.py b/terratorch/models/pixel_wise_model.py index 53c33f59..28e86ae0 100644 --- a/terratorch/models/pixel_wise_model.py +++ b/terratorch/models/pixel_wise_model.py @@ -90,7 +90,11 @@ def _check_for_single_channel_and_squeeze(x): def forward(self, x: torch.Tensor) -> ModelOutput: """Sequentially pass `x` through model`s encoder, decoder and heads""" self.check_input_shape(x) - input_size = x.shape[-2:] + if isinstance(x, torch.Tensor): + input_size = x.shape[-2:] + elif isinstance(x, dict): + # Multimodal input in passed as dict + input_size = list(x.values())[0].shape[-2:] features = self.encoder(x) ## only for backwards compatibility with pre-neck times. 
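
A minimal standalone illustration of the dict-versus-tensor input handling added to the pixel-wise model above (the modality name is a placeholder):

    import torch

    def spatial_size(x):
        # Single-modality input arrives as a tensor, multimodal input as a dict of tensors.
        if isinstance(x, torch.Tensor):
            return x.shape[-2:]
        if isinstance(x, dict):
            return list(x.values())[0].shape[-2:]
        raise TypeError(f"Unsupported input type: {type(x)}")

    print(spatial_size(torch.rand(2, 6, 224, 224)))             # torch.Size([224, 224])
    print(spatial_size({"S2L2A": torch.rand(2, 6, 224, 224)}))  # torch.Size([224, 224])
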
From 1f610e73da0b69f8c4aab1885328070a64706ff6 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 1 Nov 2024 12:54:49 +0100 Subject: [PATCH 17/42] Align multimodal data structure to single modal datasets Signed-off-by: Benedikt Blumenstiel --- .../datamodules/generic_multimodal_data_module.py | 7 ++++++- terratorch/datasets/generic_multimodal_dataset.py | 13 ++++++++++--- terratorch/datasets/transforms.py | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index b2752d16..f9c91411 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -28,6 +28,8 @@ def collate_chunk_dicts(batch_list): for key, value in batch_list[0].items(): # TODO: Handle missing modalities when allow_missing_modalities is set. if isinstance(value, torch.Tensor): batch[key] = torch.concat([chunk[key] for chunk in batch_list]) + elif isinstance(value, dict): + batch[key] = collate_chunk_dicts([chunk[key] for chunk in batch_list]) else: batch[key] = [chunk[key] for chunk in batch_list] return batch @@ -53,16 +55,19 @@ def __call__(self, batch): continue image = batch[m] if len(image.shape) == 5: + # B, C, T, H, W means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1, 1) stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1, 1) elif len(image.shape) == 4: + # B, C, H, W means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1) stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1) elif len(self.means[m]) == 1: + # B, (T,) H, W means = torch.tensor(self.means[m], device=image.device) stds = torch.tensor(self.stds[m], device=image.device) else: - msg = f"Expected batch to have 5 or 4 dimensions, but got {len(image.shape)}" + msg = f"Expected batch to have 5 or 4 dimensions or a single channel, but got {len(image.shape)}" raise Exception(msg) batch[m] = (image - means) / stds return batch diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index e059e876..4b4bb3cd 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -35,8 +35,14 @@ def __call__(self, d): if not isinstance(v, np.ndarray): new_dict[k] = v else: + # TODO: This code has hard assumptions on the data structure if k in self.modalities and len(v.shape) >= 3: # Assuming raster modalities with 3+ dimensions - v = np.moveaxis(v, -1, 0) + if len(v.shape) <= 4: + v = np.moveaxis(v, -1, 0) # C, H, W or C, T, H, W + elif len(v.shape) == 5: + v = np.moveaxis(v, -1, 1) # B, C, T, H, W + else: + raise ValueError(f'Unexpected shape for {k}: {v.shape}') new_dict[k] = torch.from_numpy(v) return new_dict @@ -261,7 +267,8 @@ def __getitem__(self, index: int) -> dict[str, Any]: if modality == 'mask': data = data[0] - if len(data.shape) >= 3: # TODO: Assumes raster modalities by 3+ dimensions. + if len(data.shape) >= 3: + # TODO: Assumes data structure to be (B), (T), C, H, W but could also be C, T, H, W # to channels last (required by albumentations) data = np.moveaxis(data, -3, -1) @@ -401,7 +408,7 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> 'Set `export TERRATORCH_NUM_VAL_PLOTS=0` before running terratorch.') image = sample[self.rgb_modality] - if len(image.shape) == 5: # TODO: Needed? Copied from generic dataest. 
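# Illustrative sketch (not part of the patch) of the recursive collation added above:
# tensors are concatenated along the batch dimension, nested dicts such as the
# per-modality "image" dict are collated recursively, and anything else is kept
# as a list. The modality name below is an assumption.
import torch

def collate_chunk_dicts(batch_list):
    batch = {}
    for key, value in batch_list[0].items():
        if isinstance(value, torch.Tensor):
            batch[key] = torch.concat([chunk[key] for chunk in batch_list])
        elif isinstance(value, dict):
            batch[key] = collate_chunk_dicts([chunk[key] for chunk in batch_list])
        else:
            batch[key] = [chunk[key] for chunk in batch_list]
    return batch

chunk = {"image": {"S1": torch.rand(1, 2, 224, 224)}, "filename": "a.tif"}
collated = collate_chunk_dicts([chunk, chunk])
print(collated["image"]["S1"].shape)   # torch.Size([2, 2, 224, 224])
print(collated["filename"])            # ['a.tif', 'a.tif']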
+ if len(image.shape) == 5: # TODO: Fix plot code. return if isinstance(image, torch.Tensor): image = image.numpy() diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index 3e089193..a213ea28 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -114,7 +114,7 @@ def __init__( def apply(self, img, **params): if self.time_dim: rearranged = rearrange( - img, "(samples time channels) height width -> samples time channels height width", + img, "(samples time channels) height width -> samples channels time height width", **self.additional_info ) else: From 0f85428a5eb7a7672bb5e6285b8a99be9df5f90b Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Fri, 1 Nov 2024 13:07:06 +0100 Subject: [PATCH 18/42] Add channel pos arg and fix multimodal transforms Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 12 ++++++++--- .../datasets/generic_multimodal_dataset.py | 21 +++++++++++-------- terratorch/datasets/transforms.py | 2 +- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index f9c91411..67aa0463 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -51,9 +51,9 @@ def __init__(self, means, stds): def __call__(self, batch): for m in self.means.keys(): - if m not in batch: + if m not in batch['image']: continue - image = batch[m] + image = batch['image'][m] if len(image.shape) == 5: # B, C, T, H, W means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1, 1) @@ -69,7 +69,7 @@ def __call__(self, batch): else: msg = f"Expected batch to have 5 or 4 dimensions or a single channel, but got {len(image.shape)}" raise Exception(msg) - batch[m] = (image - means) / stds + batch['image'][m] = (image - means) / stds return batch @@ -160,6 +160,7 @@ def __init__( chunk_data: bool = False, sample_num_modalities: int | None = None, sample_replace: bool = False, + channel_position: int = -3, **kwargs: Any, ) -> None: """Constructor @@ -271,6 +272,7 @@ def __init__( self.rgb_indices = rgb_indices self.expand_temporal_dimension = expand_temporal_dimension self.reduce_zero_label = reduce_zero_label + self.channel_position = channel_position # Transforms can be None (leads to to_tensor default), shared between modalities or individual per modality if shared_transforms: @@ -334,6 +336,7 @@ def setup(self, stage: str) -> None: no_label_replace=self.no_label_replace, expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, + channel_position=self.channel_position, ) logging.info(f'Train dataset: {len(self.train_dataset)}') if stage in ["fit", "validate"]: @@ -355,6 +358,7 @@ def setup(self, stage: str) -> None: no_label_replace=self.no_label_replace, expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, + channel_position=self.channel_position, ) logging.info(f'Val dataset: {len(self.val_dataset)}') if stage in ["test"]: @@ -376,6 +380,7 @@ def setup(self, stage: str) -> None: no_label_replace=self.no_label_replace, expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, + channel_position=self.channel_position, ) logging.info(f'Test dataset: {len(self.test_dataset)}') if stage in ["predict"] and self.predict_root: @@ -393,6 +398,7 @@ def setup(self, stage: str) -> 
None: no_label_replace=self.no_label_replace, expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, + channel_position=self.channel_position, ) logging.info(f'Predict dataset: {len(self.predict_dataset)}') diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 4b4bb3cd..a408cffc 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -21,9 +21,8 @@ from matplotlib.patches import Rectangle from torchgeo.datasets import NonGeoDataset -from terratorch.datasets.utils import (HLSBands, default_transform, filter_valid_files, generate_bands_intervals, - to_tensor) -from terratorch.datasets.transforms import MultiModalTransforms +from terratorch.datasets.utils import HLSBands, default_transform, filter_valid_files, generate_bands_intervals +from terratorch.datasets.transforms import MultimodalTransforms class MultimodalToTensor(): @@ -72,6 +71,7 @@ def __init__( no_label_replace: int | None = None, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, + channel_position: int = -1, *args, **kwargs, ) -> None: """Constructor @@ -121,6 +121,7 @@ def __init__( self.no_label_replace = no_label_replace self.reduce_zero_label = reduce_zero_label self.expand_temporal_dimension = expand_temporal_dimension + self.channel_position = channel_position if self.expand_temporal_dimension and len(dataset_bands) != self.modalities: msg = "Please provide dataset_bands for each modality when expand_temporal_dimension is True" @@ -228,15 +229,14 @@ def __init__( # If no transform is given, apply only to transform to torch tensor if isinstance(transform, A.Compose): - self.transform = MultiModalTransforms(transform) + self.transform = MultimodalTransforms(transform) elif transform is None: self.transform = MultimodalToTensor(self.modalities) - logging.warning(f'Default transforms ') else: # Modality-specific transforms transform = {m: transform[m] if m in transform else default_transform for m in self.modalities} - self.transform = MultiModalTransforms(transform, shared=False) + self.transform = MultimodalTransforms(transform, shared=False) import warnings import rasterio @@ -267,10 +267,9 @@ def __getitem__(self, index: int) -> dict[str, Any]: if modality == 'mask': data = data[0] - if len(data.shape) >= 3: - # TODO: Assumes data structure to be (B), (T), C, H, W but could also be C, T, H, W + if len(data.shape) >= 3 and self.channel_position: # to channels last (required by albumentations) - data = np.moveaxis(data, -3, -1) + data = np.moveaxis(data, self.channel_position, -1) if modality in self.filter_indices: data = data[..., self.filter_indices[modality]] @@ -333,6 +332,7 @@ def __init__( no_label_replace: int | None = None, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, + channel_position: int = -3, ) -> None: """Constructor @@ -383,6 +383,7 @@ def __init__( no_label_replace=no_label_replace, expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, + channel_position=channel_position, ) self.num_classes = num_classes self.class_names = class_names @@ -498,6 +499,7 @@ def __init__( no_label_replace: int | None = None, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, + channel_position: int = -3, ) -> None: """Constructor @@ -546,6 +548,7 @@ def __init__( no_label_replace=no_label_replace, 
expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, + channel_position=channel_position, ) def __getitem__(self, index: int) -> dict[str, Any]: diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index a213ea28..6f28b6b8 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -158,7 +158,7 @@ def get_transform_init_args_names(self): return "band_indices" -class MultiModalTransforms: +class MultimodalTransforms: """Applies albumentations transforms to multiple images""" def __init__(self, transforms: dict | A.Compose, shared : bool = True): self.transforms = transforms From 1cbe5dde336f9985c478ae9c0a387845ab033463 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 5 Nov 2024 09:57:26 +0100 Subject: [PATCH 19/42] Added MultiMAE Signed-off-by: Benedikt Blumenstiel --- examples/confs/multimae_sen1floods11.yaml | 165 ++++ terratorch/models/backbones/__init__.py | 1 + .../models/backbones/multimae/criterion.py | 195 +++++ .../backbones/multimae/input_adapters.py | 289 +++++++ .../models/backbones/multimae/multimae.py | 488 +++++++++++ .../backbones/multimae/multimae_utils.py | 336 ++++++++ .../multimae/output_adapter_utils.py | 303 +++++++ .../backbones/multimae/output_adapters.py | 759 ++++++++++++++++++ .../models/backbones/multimae_register.py | 389 +++++++++ 9 files changed, 2925 insertions(+) create mode 100644 examples/confs/multimae_sen1floods11.yaml create mode 100644 terratorch/models/backbones/multimae/criterion.py create mode 100644 terratorch/models/backbones/multimae/input_adapters.py create mode 100644 terratorch/models/backbones/multimae/multimae.py create mode 100644 terratorch/models/backbones/multimae/multimae_utils.py create mode 100644 terratorch/models/backbones/multimae/output_adapter_utils.py create mode 100644 terratorch/models/backbones/multimae/output_adapters.py create mode 100644 terratorch/models/backbones/multimae_register.py diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml new file mode 100644 index 00000000..67ee4174 --- /dev/null +++ b/examples/confs/multimae_sen1floods11.yaml @@ -0,0 +1,165 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: output + name: multimae_sen1floods11 + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 2 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: output/multimae_sen1floods11/ + +data: + class_path: GenericMultiModalDataModule + init_args: + task: 'segmentation' + batch_size: 4 + num_workers: 0 + modalities: + - S2L2A + - S1 + - LULC + rgb_modality: S2L2A # If not provided, uses first modality + rgb_indices: + - 2 + - 1 + - 0 + + train_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + 
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand + test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + + train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt + val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt + test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt + + allow_substring_split_file: True + img_grep: + S2L2A: "*_S2L2AHand.tif" + S1: "*_S1Hand.tif" + LULC: "*_LULCHand.npy" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + + means: + S2L2A: + - 1793.243 + - 1924.863 + - 2184.553 + - 2340.936 + - 2671.402 + - 3240.082 + - 3468.412 + - 3563.244 + - 3627.704 + - 3711.071 + - 3416.714 + - 2849.625 + S1: + - -12.577 + - -20.265 + LULC: + - 0 + stds: + S2L2A: + - 1160.144 + - 1201.092 + - 1219.943 + - 1397.225 + - 1400.035 + - 1373.136 + - 1429.17 + - 1485.025 + - 1447.836 + - 1652.703 + - 1471.002 + - 1365.30 + S1: + - 5.179 + - 5.872 + LULC: + - 1 + + num_classes: 2 + + train_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: ToTensorV2 + + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_factory: EncoderDecoderFactory + model_args: + decoder: FCNDecoder + backbone_pretrained: false + backbone: multimae_base + backbone_input_adapters: + - S1 + - S2L2A + - LULC + decoder_channels: 256 + num_classes: 2 + head_dropout: 0.1 + decoder_num_convs: 4 + head_channel_list: + - 256 + loss: ce + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + class_names: + - Others + - Flood + freeze_backbone: false + freeze_decoder: false + +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/terratorch/models/backbones/__init__.py b/terratorch/models/backbones/__init__.py index da4dba5d..aa36b3f9 100644 --- a/terratorch/models/backbones/__init__.py +++ b/terratorch/models/backbones/__init__.py @@ -4,5 +4,6 @@ import terratorch.models.backbones.prithvi_swin import terratorch.models.backbones.prithvi_vit import terratorch.models.backbones.scalemae +import terratorch.models.backbones.multimae_register from terratorch.models.backbones.unet import UNet diff --git a/terratorch/models/backbones/multimae/criterion.py b/terratorch/models/backbones/multimae/criterion.py new file mode 100644 index 00000000..5f3e14d6 --- /dev/null +++ b/terratorch/models/backbones/multimae/criterion.py @@ -0,0 +1,195 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
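# Usage sketch for the masked losses defined below (not part of the original file;
# the call pattern is inferred from the forward() signatures, and the shapes are
# assumptions): predictions and targets are dense maps, while the optional mask is
# a flattened patch mask of shape [B, N_H * N_W] in which 1 marks a masked patch.
#
#     import torch
#     from terratorch.models.backbones.multimae.criterion import MaskedMSELoss
#
#     criterion = MaskedMSELoss(patch_size=16, stride=1, norm_pix=False)
#     pred = torch.rand(2, 3, 224, 224)
#     target = torch.rand(2, 3, 224, 224)
#     patch_mask = torch.randint(0, 2, (2, (224 // 16) ** 2))  # one flag per 16x16 patch
#     loss = criterion(pred, target, mask=patch_mask)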
+# -------------------------------------------------------- +# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# https://github.com/facebookresearch/moco-v3 +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/BUPT-PRIV/MAE-priv +# https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class MaskedCrossEntropyLoss(nn.Module): + """Cross-entropy loss with masking + :param patch_size: Patch size + :param stride: Stride of task / modality + :param label_smoothing: Amount of smoothing in the loss (default is 0.0) + """ + + def __init__( + self, patch_size: int = 16, stride: int = 1, label_smoothing: float = 0.0 + ): + super().__init__() + self.patch_size = patch_size + self.stride = stride + self.scale_factor = patch_size // stride + self.label_smoothing = label_smoothing + + def forward(self, input, target, mask=None): + + loss = F.cross_entropy( + input, target, reduction="none", label_smoothing=self.label_smoothing + ) + + if mask is not None: + if mask.sum() == 0: + return torch.tensor(0).to(loss.device) + + H, W = input.shape[-2:] + nh, nw = H // self.scale_factor, W // self.scale_factor + # Resize mask and upsample + mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw) + mask = F.interpolate( + mask.unsqueeze(1).float(), size=(H, W), mode="nearest" + ).squeeze(1) + loss = loss * mask + # Compute mean per sample + loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum( + dim=1 + ) + loss = loss.nanmean() # Account for zero masks + else: + loss = loss.mean() # If this is ever nan, we want it to stop training + + return loss + + +class MaskedMSELoss(nn.Module): + """L1 loss with masking + :param patch_size: Patch size + :param stride: Stride of task / modality + :param norm_pix: Normalized pixel loss + """ + + def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False): + super().__init__() + self.patch_size = patch_size + self.stride = stride + self.scale_factor = patch_size // stride + self.norm_pix = norm_pix + + def patchify(self, imgs, nh, nw): + p = self.scale_factor + x = rearrange( + imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p + ) + return x + + def unpatchify(self, x, nh, nw): + p = self.scale_factor + imgs = rearrange( + x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p + ) + return imgs + + def forward(self, input, target, mask=None): + + H, W = input.shape[-2:] + nh, nw = H // self.scale_factor, W // self.scale_factor + + if self.norm_pix: + target = self.patchify(target, nh, nw) + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + eps = 1e-6 + target = (target - mean) / torch.sqrt(var + eps) + target = self.unpatchify(target, nh, nw) + + loss = F.mse_loss(input, target, reduction="none") + + if mask is not None: + if mask.sum() == 0: + return torch.tensor(0).to(loss.device) + + # Resize mask and upsample + mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw) + mask = F.interpolate( + mask.unsqueeze(1).float(), size=(H, W), mode="nearest" + ).squeeze(1) + loss = loss.mean(dim=1) # B, C, H, W -> B, H, W + loss = loss * mask + # Compute mean per sample + loss = 
loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum( + dim=1 + ) + loss = loss.nanmean() # Account for zero masks + else: + loss = loss.mean() # If this is ever nan, we want it to stop training + + return loss + + +class MaskedL1Loss(nn.Module): + """L1 loss with masking + :param patch_size: Patch size + :param stride: Stride of task / modality + :param norm_pix: Normalized pixel loss + """ + + def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False): + super().__init__() + self.patch_size = patch_size + self.stride = stride + self.scale_factor = patch_size // stride + self.norm_pix = norm_pix + + def patchify(self, imgs, nh, nw): + p = self.scale_factor + x = rearrange( + imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p + ) + return x + + def unpatchify(self, x, nh, nw): + p = self.scale_factor + imgs = rearrange( + x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p + ) + return imgs + + def forward(self, input, target, mask=None): + + H, W = input.shape[-2:] + nh, nw = H // self.scale_factor, W // self.scale_factor + + if self.norm_pix: + target = self.patchify(target, nh, nw) + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + eps = 1e-6 + target = (target - mean) / torch.sqrt(var + eps) + target = self.unpatchify(target, nh, nw) + + loss = F.l1_loss(input, target, reduction="none") + + if mask is not None: + if mask.sum() == 0: + return torch.tensor(0).to(loss.device) + + # Resize mask and upsample + mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw) + mask = F.interpolate( + mask.unsqueeze(1).float(), size=(H, W), mode="nearest" + ).squeeze(1) + loss = loss.mean(dim=1) # B, C, H, W -> B, H, W + loss = loss * mask + # Compute mean per sample + loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum( + dim=1 + ) + loss = loss.nanmean() # Account for zero masks + else: + loss = loss.mean() # If this is ever nan, we want it to stop training + + return loss diff --git a/terratorch/models/backbones/multimae/input_adapters.py b/terratorch/models/backbones/multimae/input_adapters.py new file mode 100644 index 00000000..464a93fa --- /dev/null +++ b/terratorch/models/backbones/multimae/input_adapters.py @@ -0,0 +1,289 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# https://github.com/facebookresearch/moco-v3 +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/BUPT-PRIV/MAE-priv +# https://github.com/facebookresearch/mae +# -------------------------------------------------------- + + +import math + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from terratorch.models.backbones.multimae.multimae_utils import build_2d_sincos_posemb, pair, trunc_normal_ + + +class PatchedInputAdapter(nn.Module): + """Adapter for spatial inputs, like images or feature maps. + Creates tokens from patches over the image. + + :param num_channels: Number of input channels of the image/feature map + :param stride_level: Stride level compared to the full-sized image. + E.g. 
4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param dim_tokens: Dimension of output tokens. Can be set using init method. + :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings + :param learnable_pos_emb: Set to True to learn positional embeddings instead + :param image_size: Default image size. Used to initialize size of positional embeddings. + """ + + def __init__( + self, + num_channels: int, + stride_level: int, + patch_size_full: int | tuple[int, int], + dim_tokens: int | None = None, + sincos_pos_emb: bool = True, + learnable_pos_emb: bool = False, + image_size: int | tuple[int] = 224, + ): + + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size_full = pair(patch_size_full) + self.dim_tokens = dim_tokens + self.sincos_pos_emb = sincos_pos_emb + self.learnable_pos_emb = learnable_pos_emb + self.image_size = pair(image_size) + self.num_patches = (self.image_size[0] // patch_size_full) * ( + self.image_size[1] // patch_size_full + ) + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size_full[0] // stride_level) + self.P_W = max(1, self.patch_size_full[1] // stride_level) + + if self.dim_tokens is not None: + self.init(dim_tokens=dim_tokens) + + def init(self, dim_tokens: int = 768): + """ + Initialize parts of encoder that are dependent on dimension of tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens: Dimension of tokens + """ + self.dim_tokens = dim_tokens + + # Task embedding identifying from which task a given token comes from + # Fixed-size positional embeddings. Can be interpolated to different input sizes + h_posemb = self.image_size[0] // (self.stride_level * self.P_H) + w_posemb = self.image_size[1] // (self.stride_level * self.P_W) + if self.sincos_pos_emb: + self.pos_emb = build_2d_sincos_posemb( + h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens + ) + self.pos_emb = nn.Parameter( + self.pos_emb, requires_grad=self.learnable_pos_emb + ) + else: + self.pos_emb = nn.Parameter( + torch.zeros(1, self.dim_tokens, h_posemb, w_posemb) + ) + trunc_normal_(self.pos_emb, std=0.02) + + # Image -> tokens projection + self.proj = nn.Conv2d( + in_channels=self.num_channels, + out_channels=self.dim_tokens, + kernel_size=(self.P_H, self.P_W), + stride=(self.P_H, self.P_W), + padding=0, + # TODO: Not working with padding! Misalignment between pos embed and x_patch. + # padding=(self.P_H // 2, self.P_W // 2) # padding taken from segformer so it should work for any image size + ) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_emb"} + + def forward(self, x): + """ + Forward pass through input adapter, transforming image to sequence of tokens. + Adds task and positional encodings. + + :param x: Input image tensor + """ + B, C, H, W = x.shape + assert ( + self.dim_tokens is not None + ), "Need to call init(dim_tokens) function first" + # assert (H % self.P_H == 0) and ( + # W % self.P_W == 0 + # ), f"Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}" + N_H, N_W = math.ceil(H / self.P_H), math.ceil(W / self.P_W) # Number of patches in height and width. 
Adapted for segformer padding + + # Create patches [B, C, H, W] -> [B, (H*W), C] + x_patch = rearrange(self.proj(x), "b d nh nw -> b (nh nw) d") + + # Create positional embedding + x_pos_emb = F.interpolate( + self.pos_emb, size=(N_H, N_W), mode="bicubic", align_corners=False + ) + x_pos_emb = rearrange(x_pos_emb, "b d nh nw -> b (nh nw) d") + + # Add patches and positional embeddings + x = x_patch + x_pos_emb + + return x + + +class SemSegInputAdapter(nn.Module): + """ + Adapter for spatial inputs, like images or feature maps. + Creates tokens from patches over the image. + + :param num_classes: Number of input semantic classes + :param stride_level: Stride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param dim_tokens: Dimension of output tokens. Can be set using init method. + :param sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings + :param learnable_pos_emb: Set to True to learn positional embeddings instead + :param image_size: Default image size. Used to initialize size of positional embeddings. + :param dim_class_emb: Dimension of learned class embedding + :param interpolate_class_emb: Set to True to average pool class embeddings of each patch + :param emb_padding_idx: Padding index (e.g. image border), default is None + """ + + def __init__( + self, + num_classes: int, + stride_level: int, + patch_size_full: int | tuple[int, int], + dim_tokens: int | None = None, + sincos_pos_emb: int = True, + learnable_pos_emb: int = False, + image_size: int | tuple[int] = 224, + dim_class_emb: int = 64, + interpolate_class_emb: bool = False, + emb_padding_idx: int | None = None, + ): + super().__init__() + self.num_classes = num_classes + self.stride_level = stride_level + self.patch_size_full = pair(patch_size_full) + self.dim_tokens = dim_tokens + self.sincos_pos_emb = sincos_pos_emb + self.learnable_pos_emb = learnable_pos_emb + self.image_size = pair(image_size) + self.dim_class_emb = dim_class_emb + self.interpolate_class_emb = interpolate_class_emb + self.emb_padding_idx = emb_padding_idx + if self.emb_padding_idx is not None: + self.num_classes += 1 + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size_full[0] // stride_level) + self.P_W = max(1, self.patch_size_full[1] // stride_level) + + if self.dim_tokens is not None: + self.init(dim_tokens=dim_tokens) + + def init(self, dim_tokens: int = 768): + """ + Initialize parts of encoder that are dependent on dimension of tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens: Dimension of tokens + """ + self.dim_tokens = dim_tokens + + # Task embedding identifying from which task a given token comes from + # Fixed-size positional embeddings. 
Can be interpolated to different input sizes + h_posemb = self.image_size[0] // (self.stride_level * self.P_H) + w_posemb = self.image_size[1] // (self.stride_level * self.P_W) + if self.sincos_pos_emb: + self.pos_emb = build_2d_sincos_posemb( + h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens + ) + self.pos_emb = nn.Parameter( + self.pos_emb, requires_grad=self.learnable_pos_emb + ) + else: + self.pos_emb = nn.Parameter( + torch.zeros(1, self.dim_tokens, h_posemb, w_posemb) + ) + trunc_normal_(self.pos_emb, std=0.02) + + # Image -> tokens projection + self.class_emb = nn.Embedding( + num_embeddings=self.num_classes, + embedding_dim=self.dim_class_emb, + padding_idx=self.emb_padding_idx, + ) + trunc_normal_(self.class_emb.weight, std=0.02) + + if self.interpolate_class_emb: + self.proj = nn.Sequential( + nn.Upsample( + scale_factor=(1 / self.P_H, 1 / self.P_W), mode="bilinear" + ), # Actually a downsample operation + nn.Conv2d( + in_channels=self.dim_class_emb, + out_channels=self.dim_tokens, + kernel_size=1, + stride=1, + ), + ) + else: + self.proj = nn.Conv2d( + in_channels=self.dim_class_emb, + out_channels=self.dim_tokens, + kernel_size=(self.P_H, self.P_W), + stride=(self.P_H, self.P_W), + ) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_emb", "class_emb"} + + def forward(self, x): + """ + Forward pass through input adapter, transforming image to sequence of tokens. + Adds task and positional encodings. + + :param x: Input image tensor + """ + if len(x.shape) == 4: + # Remove channel dim + x = x.squeeze() + + B, H, W = x.shape + assert ( + self.dim_tokens is not None + ), "Need to call init(dim_tokens) function first" + assert (H % self.P_H == 0) and ( + W % self.P_W == 0 + ), f"Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}" + N_H, N_W = H // self.P_H, W // self.P_W # Number of patches in height and width + + # Map to embedding + x = rearrange(self.class_emb(x.to(int)), "b nh nw c -> b c nh nw") + + # Create patches [B, C, H, W] -> [B, (H*W), C] + x_patch = rearrange(self.proj(x), "b d nh nw -> b (nh nw) d") + + # Create positional embedding + x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode="bilinear") + x_pos_emb = rearrange(x_pos_emb, "b d nh nw -> b (nh nw) d") + + # Add patches and positional embeddings + x = x_patch + x_pos_emb + + return x diff --git a/terratorch/models/backbones/multimae/multimae.py b/terratorch/models/backbones/multimae/multimae.py new file mode 100644 index 00000000..1f12bcc8 --- /dev/null +++ b/terratorch/models/backbones/multimae/multimae.py @@ -0,0 +1,488 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
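# Before the encoder below, a short illustrative sketch of how the input adapters
# defined in input_adapters.py turn a single modality into a token sequence
# (channel count, patch size and token dimension are assumptions for the example):
import torch
from terratorch.models.backbones.multimae.input_adapters import PatchedInputAdapter

adapter = PatchedInputAdapter(num_channels=12, stride_level=1,
                              patch_size_full=16, image_size=224)
adapter.init(dim_tokens=768)                 # normally called by MultiMAE during setup
tokens = adapter(torch.rand(2, 12, 224, 224))
print(tokens.shape)                          # torch.Size([2, 196, 768]): 14x14 patches + pos. embedding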
+# -------------------------------------------------------- +# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# https://github.com/facebookresearch/moco-v3 +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/BUPT-PRIV/MAE-priv +# https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import itertools +import math +from collections import OrderedDict +from functools import partial + +import torch +from einops import repeat +from torch import nn +from torch.distributions.dirichlet import Dirichlet + +from .multimae_utils import Block, trunc_normal_ + + +class MultiMAE(nn.Module): + """MultiMAE: Multi-task Multi-modal Masked Autoencoder + This module performs masking in its forward pass. + The MultiViT module defined below inherits from this module and performs a regular forward pass, + and should be used instead for downstream tasks + + + :param input_adapters: Dictionary of task -> input adapters + :param output_adapters: Optional dictionary of task -> output adapters + + :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1 + :param dim_tokens: Dimension of encoder tokens + :param depth: Depth of encoder + :param num_heads: Number of attention heads + :param mlp_ratio: MLP hidden dim ratio + :param qkv_bias: Set to False to disable bias + :param drop_rate: Dropout after MLPs and Attention + :param attn_drop_rate: Attention matrix drop rate + :param drop_path_rate: DropPath drop rate + :param norm_layer: Type of normalization layer + """ + + default_norm_layer = partial(nn.LayerNorm, eps=1e-6) + + def __init__( + self, + input_adapters: dict[str, nn.Module], + output_adapters: dict[str, nn.Module] | None, + num_global_tokens: int = 1, + dim_tokens: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: nn.Module = default_norm_layer, + merging_method: str = None, + **kwargs, + ): + super().__init__() + + # Initialize input and output adapters + for adapter in input_adapters.values(): + adapter.init(dim_tokens=dim_tokens) + self.input_adapters = nn.ModuleDict(input_adapters) + if output_adapters is not None: + for adapter in output_adapters.values(): + adapter.init(dim_tokens_enc=dim_tokens) + self.output_adapters = nn.ModuleDict(output_adapters) + else: + self.output_adapters = None + + # Additional learnable tokens that can be used by encoder to process/store global information + self.num_global_tokens = num_global_tokens + self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens)) + trunc_normal_(self.global_tokens, std=0.02) + + # Transformer encoder + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + # Encoder init is adapted for timm registry + self.feature_info = [] + self.layers = [] + scale = 1 + self.out_channels = [] + if merging_method == 'concat': # TODO: Move prepare/concat to this model forward? 
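# Inferred behaviour of the branch below: with merging_method='concat' each timm
# feature_info entry reports num_chs = dim_tokens * len(input_adapters), assuming
# the per-modality tokens are concatenated along the embedding dimension outside
# this module (see the TODO above); with any other merging_method the factor is 1
# and num_chs is simply dim_tokens.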
+ embed_factor = len(input_adapters) + else: + embed_factor = 1 + + for i in range(depth): + layer = Block( + dim=dim_tokens, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + self.layers.append(layer) + # TODO: Scale needed? Check what reduction means + if i > 0: + scale *= 2 + self.feature_info += [ + { + "num_chs": int(dim_tokens) * embed_factor, + "reduction": scale, + "module": f"layers.{i}", + } + ] + + self.layers: nn.ModuleList = nn.ModuleList(self.layers) # added for compatibility with timm features_only + + self.apply(self._init_weights) + for name, m in self.named_modules(): + if isinstance(m, nn.Linear): + if "qkv" in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + elif "kv" in name: + # treat the weights of K, V separately + val = math.sqrt(6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1])) + nn.init.uniform_(m.weight, -val, val) + + if isinstance(m, nn.Conv2d): + if ".proj" in name: + # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d) + w = m.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.layers) + + @torch.jit.ignore + def no_weight_decay(self): + no_wd_set = {"global_tokens"} + + for task, adapter in self.input_adapters.items(): + if hasattr(adapter, "no_weight_decay"): + to_skip = adapter.no_weight_decay() + to_skip = {f"input_adapters.{task}.{name}" for name in to_skip} + no_wd_set = no_wd_set | to_skip + + for task, adapter in self.output_adapters.items(): + if hasattr(adapter, "no_weight_decay"): + to_skip = adapter.no_weight_decay() + to_skip = {f"output_adapters.{task}.{name}" for name in to_skip} + no_wd_set = no_wd_set | to_skip + + return no_wd_set + + def sample_alphas(self, B: int, n_tasks: int, alphas: float = 1.0, eps: float = 1e-5): + """ + Sample alphas for Dirichlet sampling such that tasks are first uniformly chosen and then Dirichlet sampling + is performed over the chosen ones. + + :param B: Batch size + :param n_tasks: Number of input tasks + :param alphas: Float or list to multiply task choices {0,1} by + :param eps: Small constant since Dirichlet alphas need to be positive + """ + valid_task_choices = torch.Tensor([list(i) for i in itertools.product([0, 1], repeat=n_tasks)][1:]) + rand_per_sample_choice = torch.randint(0, len(valid_task_choices), (B,)) + alphas_tensor = torch.index_select(valid_task_choices, 0, rand_per_sample_choice) + alphas_tensor = alphas_tensor * torch.tensor(alphas) + eps + return alphas_tensor + + def generate_random_masks( + self, + input_tokens: dict[str, torch.Tensor], + num_encoded_tokens: int, + alphas: float | list[float] = 1.0, + sample_tasks_uniformly: bool = False, + ): + """ + Sample a total of num_encoded_tokens from different tasks using Dirichlet sampling. + + :param input_tokens: Dictionary of tensors to sample num_encoded_tokens from + :param num_encoded_tokens: Number of tokens to select + :param alphas: Dirichlet distribution parameter alpha. Lower alpha = harder, + less uniform sampling. 
Can be float or list of floats. + :param sample_tasks_uniformly: Set to True to first sample 1-n_tasks uniformly at random + for each sample in the batch. Dirichlet sampling is then done over selected subsets. + """ + B = list(input_tokens.values())[0].shape[0] + device = next(iter(input_tokens.values())).device + + alphas = [alphas] * len(input_tokens) if isinstance(alphas, float) else alphas + if sample_tasks_uniformly: + alphas = self.sample_alphas(B, len(input_tokens), alphas=alphas) + task_sampling_dist = Dirichlet(alphas).sample().to(device) + else: + task_sampling_dist = Dirichlet(torch.Tensor(alphas)).sample((B,)).to(device) + + samples_per_task = (task_sampling_dist * num_encoded_tokens).round().long() + + task_masks = [] + num_tokens_per_task = [task_tokens.shape[1] for task_tokens in input_tokens.values()] + for i, num_tokens in enumerate(num_tokens_per_task): + # Use noise to shuffle arange + noise = torch.rand(B, num_tokens, device=device) # noise in [0, 1] + ids_arange_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + mask = torch.arange(num_tokens, device=device).unsqueeze(0).expand(B, -1) + mask = torch.gather(mask, dim=1, index=ids_arange_shuffle) + # 0 is keep (unmasked), 1 is remove (masked) + mask = torch.where(mask < samples_per_task[:, i].unsqueeze(1), 0, 1) + task_masks.append(mask) + + mask_all = torch.cat(task_masks, dim=1) + ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_keep = ids_shuffle[:, :num_encoded_tokens] + + # Update binary mask to adjust for task rounding + mask_all = torch.ones_like(mask_all) + mask_all[:, :num_encoded_tokens] = 0 + # Unshuffle to get the binary mask + mask_all = torch.gather(mask_all, dim=1, index=ids_restore) + # Split to get task masks + task_masks = torch.split(mask_all, num_tokens_per_task, dim=1) + # Convert to dict + task_masks = dict(zip(input_tokens.keys(), task_masks, strict=False)) + + return task_masks, ids_keep, ids_restore + + @staticmethod + def make_mask( + N_H, + N_W, + xy_idxs, + full_tasks=[], + indicate_visible=True, + flatten=True, + device="cuda", + ): + """ + Creates masks for each task, given lists of un-masked x,y coordinates. 
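# A standalone illustration of the Dirichlet step in generate_random_masks above
# (the numbers are assumptions): each modality's share of the num_encoded_tokens
# budget is drawn from a Dirichlet distribution and rounded to integer counts.
#
#     import torch
#     from torch.distributions.dirichlet import Dirichlet
#
#     num_encoded_tokens = 128
#     alphas = torch.tensor([1.0, 1.0, 1.0])              # one alpha per modality
#     shares = Dirichlet(alphas).sample((4,))             # batch of 4 samples
#     samples_per_task = (shares * num_encoded_tokens).round().long()
#     # e.g. tensor([[53, 40, 35], ...]) -> tokens kept per modality and sample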
+ """ + xy_idxs = {k: torch.LongTensor(v) for k, v in xy_idxs.items()} + + task_masks = {k: torch.ones(N_H, N_W).to(device) for k in xy_idxs.keys()} + + for k in xy_idxs.keys(): + if len(xy_idxs[k]) > 0: + task_masks[k][xy_idxs[k][:, 1], xy_idxs[k][:, 0]] = 0 + + for task in full_tasks: + task_masks[task][:] = 0 + + if not indicate_visible: + task_masks = {k: 1 - v for k, v in task_masks.items()} + + if flatten: + task_masks = {k: v.flatten().unsqueeze(0) for k, v in task_masks.items()} + + return task_masks + + def generate_input_info(self, input_task_tokens, image_size): + input_info = OrderedDict() + i = 0 + input_info["tasks"] = {} + for domain, tensor in input_task_tokens.items(): + num_tokens = tensor.shape[1] + d = { + "num_tokens": num_tokens, + "has_2d_posemb": True, # TODO: Modify when adding non-2D tasks + "start_idx": i, + "end_idx": i + num_tokens, + } + i += num_tokens + input_info["tasks"][domain] = d + + input_info["image_size"] = image_size + input_info["num_task_tokens"] = i + input_info["num_global_tokens"] = self.num_global_tokens + + return input_info + + def forward( + self, + x: dict[str, torch.Tensor] | torch.Tensor, + mask_inputs: bool = False, + task_masks: dict[str, torch.Tensor] = None, + num_encoded_tokens: int = 128, + alphas: float | list[float] = 1.0, + sample_tasks_uniformly: bool = False, + fp32_output_adapters: list[str] | None = None, + ): + """ + Forward pass through input adapters, transformer encoder and output adapters. + If specified, will randomly drop input tokens. + + :param x: Input tensor or dictionary of tensors + :param mask_inputs: Set to True to enable random masking of input patches + :param task_masks: Optional dictionary of task->mask pairs. + :param num_encoded_tokens: Number of tokens to randomly select for encoder. + Only used if mask_inputs is True. + :param alphas: Dirichlet distribution parameter alpha for task sampling. + Higher alpha = harder, less uniform sampling. Can be float or list of floats. + :param sample_tasks_uniformly: Set to True if tasks should be uniformly presampled, + before Dirichlet sampling decides share of masked tokens between them. + :param fp32_output_adapters: List of task identifiers to force output adapters to + run with mixed precision turned off for stability reasons. 
+ """ + + if fp32_output_adapters is None: + fp32_output_adapters = [] + ## Processing input modalities + # If input x is a Tensor, assume it's RGB + x = {"rgb": x} if isinstance(x, torch.Tensor) else x + + # Need image size for tokens->image reconstruction + # We assume that at least one of rgb or semseg is given as input before masking + if "rgb" in x: + B, C, H, W = x["rgb"].shape + elif "semseg" in x: + B, H, W = x["semseg"].shape + H *= self.input_adapters["semseg"].stride_level + W *= self.input_adapters["semseg"].stride_level + else: + shape = list(x.values())[0].shape + B = shape[0] + H, W = shape[2:] + + # Encode selected inputs to tokens + input_task_tokens = { + domain: self.input_adapters[domain](tensor) + for domain, tensor in x.items() + if domain in self.input_adapters + } + + input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W)) + + # Select random subset of tokens from the chosen input tasks and concatenate them + if mask_inputs: + num_encoded_tokens = num_encoded_tokens if num_encoded_tokens is not None else self.num_encoded_tokens + else: + num_encoded_tokens = sum([tensor.shape[1] for tensor in input_task_tokens.values()]) + + ## Generating masks + if task_masks is None: + task_masks, ids_keep, ids_restore = self.generate_random_masks( + input_task_tokens, num_encoded_tokens, alphas=alphas, sample_tasks_uniformly=sample_tasks_uniformly + ) + else: + mask_all = torch.cat([task_masks[task] for task in input_task_tokens.keys()], dim=1) + ids_shuffle = torch.argsort(mask_all, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + ids_keep = ids_shuffle[:, : (mask_all == 0).sum()] + + input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1) + + # Apply mask + input_tokens = torch.gather(input_tokens, dim=1, + index=ids_keep.unsqueeze(-1).repeat(1, 1, input_tokens.shape[2])) + + # Add global tokens to input tokens + global_tokens = repeat(self.global_tokens, "() n d -> b n d", b=B) + input_tokens = torch.cat([input_tokens, global_tokens], dim=1) + + ## Transformer forward pass + outputs = [] + encoder_tokens = input_tokens + for layer in self.layers: + encoder_tokens = layer(encoder_tokens) + outputs.append(encoder_tokens) + + ## Output decoders + if self.output_adapters is None: + return outputs, task_masks + + # Decode tokens for each task using task-specific output adapters + preds = { + domain: self.output_adapters[domain]( + encoder_tokens=encoder_tokens, + input_info=input_info, + ids_keep=ids_keep, + ids_restore=ids_restore, + ) + for domain in self.output_adapters + if domain not in fp32_output_adapters + } + # Force running selected output adapters in fp32 mode + with torch.cuda.amp.autocast(enabled=False): + for domain in fp32_output_adapters: + if domain not in self.output_adapters: + continue + preds[domain] = self.output_adapters[domain]( + encoder_tokens=encoder_tokens.float(), + input_info=input_info, + ids_keep=ids_keep, + ids_restore=ids_restore, + ) + + return preds, task_masks + + +class MultiViT(MultiMAE): + """MultiViT: Multi-modal Vision Transformer + This is MultiMAE without masking and with a simplified / faster forward pass + + + :param input_adapters: Dictionary of task -> input adapters + :param output_adapters: Optional dictionary of task -> output adapters + + :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1 + :param dim_tokens: Dimension of encoder tokens + :param depth: Depth of encoder + :param num_heads: Number of 
attention heads + :param mlp_ratio: MLP hidden dim ratio + :param qkv_bias: Set to False to disable bias + :param drop_rate: Dropout after MLPs and Attention + :param attn_drop_rate: Attention matrix drop rate + :param drop_path_rate: DropPath drop rate + :param norm_layer: Type of normalization layer + """ + + def process_input(self, x): + # If input x is a Tensor, assume it's RGB + x = {"rgb": x} if isinstance(x, torch.Tensor) else x + # Need image size for tokens->image reconstruction + if "rgb" in x: + B, _, H, W = x["rgb"].shape + elif "semseg" in x: + B, H, W = x["semseg"].shape + H *= self.input_adapters["semseg"].stride_level + W *= self.input_adapters["semseg"].stride_level + else: + B, _, H, W = list(x.values())[0].shape # TODO: Deal with case where not all have same shape + + # Encode selected inputs to tokens + input_task_tokens = { + domain: self.input_adapters[domain](tensor) + for domain, tensor in x.items() + if domain in self.input_adapters + } + + input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W)) + input_tokens = torch.cat(list(input_task_tokens.values()), dim=1) + + # Add global tokens to input tokens + global_tokens = repeat(self.global_tokens, "() n d -> b n d", b=B) + input_tokens = torch.cat([input_tokens, global_tokens], dim=1) + return input_tokens, input_info + + def forward( + self, + x: dict[str, torch.Tensor] | torch.Tensor, + ): + """ + Forward pass through input adapters, transformer encoder and output adapters. + + :param x: Input tensor or dictionary of tensors + :param return_all_layers: Set to True to return all transformer layers + """ + + num_modalities = len(x) + input_tokens, _ = self.process_input(x) + + encoder_tokens = [] + tokens = input_tokens + for block in self.layers: + tokens = block(tokens) + encoder_tokens.append(tokens) + return encoder_tokens diff --git a/terratorch/models/backbones/multimae/multimae_utils.py b/terratorch/models/backbones/multimae/multimae_utils.py new file mode 100644 index 00000000..0d218f5f --- /dev/null +++ b/terratorch/models/backbones/multimae/multimae_utils.py @@ -0,0 +1,336 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
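# Illustrative end-to-end sketch for the MultiViT encoder defined in multimae.py
# above (modality names, adapter settings and input sizes are assumptions):
import torch
from terratorch.models.backbones.multimae.multimae import MultiViT
from terratorch.models.backbones.multimae.input_adapters import PatchedInputAdapter

input_adapters = {
    "S2L2A": PatchedInputAdapter(num_channels=12, stride_level=1, patch_size_full=16),
    "S1": PatchedInputAdapter(num_channels=2, stride_level=1, patch_size_full=16),
}
model = MultiViT(input_adapters=input_adapters, output_adapters=None,
                 dim_tokens=768, depth=12, num_heads=12)

x = {"S2L2A": torch.rand(2, 12, 224, 224), "S1": torch.rand(2, 2, 224, 224)}
layer_outputs = model(x)        # one token tensor per transformer block
print(len(layer_outputs), layer_outputs[-1].shape)
# 12 torch.Size([2, 393, 768])  -> 196 + 196 modality tokens + 1 global token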
+# -------------------------------------------------------- +# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# https://github.com/facebookresearch/moco-v3 +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/BUPT-PRIV/MAE-priv +# https://github.com/facebookresearch/mae +# -------------------------------------------------------- + +import math +import warnings + +import torch +import torch.nn as nn +from einops import rearrange + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.0): + """Sine-cosine positional embeddings from MoCo-v3 + + Source: https://github.com/facebookresearch/moco-v3/blob/main/vits.py + """ + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h) + assert ( + embed_dim % 4 == 0 + ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega]) + out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega]) + pos_emb = torch.cat( + [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1 + )[None, :, :] + pos_emb = rearrange(pos_emb, "b (h w) d -> b d h w", h=h, w=w, d=embed_dim) + return pos_emb + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
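# Quick shape check for build_2d_sincos_posemb above (illustrative only):
#
#     from terratorch.models.backbones.multimae.multimae_utils import build_2d_sincos_posemb
#
#     pos_emb = build_2d_sincos_posemb(h=14, w=14, embed_dim=768)
#     print(pos_emb.shape)   # torch.Size([1, 768, 14, 14]); interpolated to the
#                            # actual patch grid inside the input adapters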
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + + self.attn_drop = 
nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, context): + B, N, C = x.shape + _, M, _ = context.shape + + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + kv = ( + self.kv(context) + .reshape(B, M, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class DecoderBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.self_attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.cross_attn = CrossAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.query_norm = norm_layer(dim) + self.context_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, context): + x = x + self.drop_path(self.self_attn(self.norm1(x))) + x = x + self.drop_path( + self.cross_attn(self.query_norm(x), self.context_norm(context)) + ) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x diff --git a/terratorch/models/backbones/multimae/output_adapter_utils.py b/terratorch/models/backbones/multimae/output_adapter_utils.py new file mode 100644 index 00000000..25b2a2d6 --- /dev/null +++ b/terratorch/models/backbones/multimae/output_adapter_utils.py @@ -0,0 +1,303 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Based on timm, DPT and ConvNeXt code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/isl-org/DPT +# https://github.com/facebookresearch/ConvNeXt +# -------------------------------------------------------- + +import torch +import torch.nn as nn + +from .multimae_utils import DropPath + + +class ConvNeXtBlock(nn.Module): + r"""ConvNeXt Block. 
There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path: Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 0 (disabled for isotropic ConvNeXt). + + Code from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py + """ + + def __init__(self, dim, drop_path=0.0, layer_scale_init_value=0.0): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList( + [ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ] + ) + + return scratch + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +def make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x diff --git a/terratorch/models/backbones/multimae/output_adapters.py b/terratorch/models/backbones/multimae/output_adapters.py new file mode 100644 index 00000000..cef8d22b --- /dev/null +++ b/terratorch/models/backbones/multimae/output_adapters.py @@ -0,0 +1,759 @@ +# Copyright (c) EPFL VILAB. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv MAE, DPT and ConvNeXt code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit +# https://github.com/facebookresearch/dino +# https://github.com/facebookresearch/moco-v3 +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/BUPT-PRIV/MAE-priv +# https://github.com/facebookresearch/mae +# https://github.com/isl-org/DPT +# https://github.com/facebookresearch/ConvNeXt +# -------------------------------------------------------- + +from functools import partial +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from .multimae_utils import (Block, CrossAttention, Mlp, + build_2d_sincos_posemb, pair, trunc_normal_) +from .output_adapter_utils import (ConvNeXtBlock, Interpolate, + make_fusion_block, make_scratch) + + +class SpatialOutputAdapter(nn.Module): + """Cross-attention adapter for spatial outputs, like images or feature maps. + + :param num_channels: Number of input channels of the image/feature map + :param stride_level: Stride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param dim_tokens_enc: Dimension of tokens coming from encoder. Can be set using init method. + :param dim_tokens: Dimension of decoder tokens + :param depth: Number of additional (full self-attention) transformer layers after initial cross attention and MLP + :param learnable_pos_emb: Set to True to learn positional embeddings instead + :param image_size: Default image size. Used to initialize size of positional embeddings. + :param mlp_ratio: MLP hidden dim ratio + :param num_heads: Number of attention heads + :param qkv_bias: Set to True to enable bias + :param drop_rate: Probability of dropping attention layer outputs + :param attn_drop_rate: Probability of dropping attention matrix elements + :param drop_path_rate: DropPath drop rate + :param norm_layer: Type of normalization layer + :param use_task_queries: When set to True, adds task specific tokens from encoder (if available) + to the corresponding query entries + :param task: Task for which encoder tokens are added to the queries of the decoder (e.g. RGB if decoder is used for RGB) + :param context_tasks: Tasks / modalities from the encoder. Used to create learned embeddings for each task. 
+ :param use_xattn: When set to True, attend to the tokens from the encoder through a cross-attention layer + """ + + def __init__(self, + num_channels: int, + stride_level: int, + patch_size_full: Union[int, Tuple[int, int]], + dim_tokens_enc: Optional[int] = None, + dim_tokens: int = 256, + depth: int = 2, + learnable_pos_emb: int = False, + image_size: Union[int, Tuple[int]] = 224, + mlp_ratio: int = 4.0, + num_heads: int = 8, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + use_task_queries: bool = True, + task: Optional[str] = None, + context_tasks: Optional[list] = None, + use_xattn: bool = True + ): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size_full = pair(patch_size_full) + self.dim_tokens_enc = dim_tokens_enc + self.dim_tokens = dim_tokens + self.learnable_pos_emb = learnable_pos_emb + self.image_size = pair(image_size) + self.use_task_queries = use_task_queries + self.task = task + self.use_xattn = use_xattn + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size_full[0] // stride_level) + self.P_W = max(1, self.patch_size_full[1] // stride_level) + + if context_tasks is not None: + self.task_embeddings = nn.ParameterDict( + {task: nn.Parameter(torch.zeros(1, 1, self.dim_tokens)) for task in context_tasks}) + for embedding in self.task_embeddings.values(): + trunc_normal_(embedding, std=0.02) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.dim_tokens)) + + # Fixed-size positional embeddings. Can be interpolated to different input sizes + h_posemb = self.image_size[0] // (self.stride_level * self.P_H) + w_posemb = self.image_size[1] // (self.stride_level * self.P_W) + if not self.learnable_pos_emb: + self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens) + self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=False) + else: + self.pos_emb = nn.Parameter(torch.zeros(1, h_posemb, w_posemb, self.dim_tokens)) + trunc_normal_(self.pos_emb, std=0.02) + + # One cross attention layer followed by MLP block, an optional transformer, and an output projection + if self.use_xattn: + self.decoder = CrossAttention( + dim=self.dim_tokens, num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, proj_drop=drop_rate) + self.context_norm = norm_layer(self.dim_tokens) + self.query_norm = norm_layer(self.dim_tokens) + self.out_norm = norm_layer(self.dim_tokens) + + mlp_hidden_dim = int(self.dim_tokens * mlp_ratio) + self.mlp = Mlp(in_features=self.dim_tokens, hidden_features=mlp_hidden_dim) + + # Optional full self-attention transformer layers + if depth > 0: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.decoder_transformer = nn.Sequential(*[ + Block(dim=self.dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth) + ]) + else: + self.decoder_transformer = nn.Identity() + + self.dim_patch = self.num_channels * self.P_H * self.P_W + self.out_proj = nn.Linear(self.dim_tokens, self.dim_patch) + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc: int = 768): + ''' + Initialize parts of decoder that are dependent on dimension of encoder tokens. 
+ Should be called when setting up MultiMAE. + + :param dim_tokens_enc: Dimension of tokens coming from encoder + ''' + self.dim_tokens_enc = dim_tokens_enc + + # Projection of encoder tokens to the patch dimension + self.proj_context = nn.Linear(self.dim_tokens_enc, self.dim_tokens) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_emb', 'mask_token', 'task_embeddings'} + + def generate_context_embeddings(self, input_info, + bs: int, + size: Tuple[int, int], + device: Optional[torch.device] = None): + context_embeddings = [] + for task, info in input_info["tasks"].items(): + if self.task_embeddings is not None and task in self.task_embeddings: + task_emb = repeat(self.task_embeddings[task], '() () d -> b n d', b=bs, n=info['num_tokens']) + else: + task_emb = torch.zeros((bs, info['num_tokens'], self.dim_tokens), device=device) + + if info['has_2d_posemb']: + pos_emb = F.interpolate(self.pos_emb, size=size, mode='bilinear', align_corners=False) + pos_emb = rearrange(pos_emb, 'b d nh nw -> b (nh nw) d') + assert info['num_tokens'] == pos_emb.shape[1] + task_emb = task_emb + pos_emb + + context_embeddings.append(task_emb) + + context_embeddings = torch.cat(context_embeddings, dim=1) + + return context_embeddings + + def get_queries_and_context(self, context_tokens, input_info, ids_keep, ids_restore): + B = context_tokens.shape[0] + H, W = input_info['image_size'] + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + if 'num_global_tokens' in input_info: + context_tokens_without_global = context_tokens[:, :-input_info['num_global_tokens']] + else: + context_tokens_without_global = context_tokens + + # Add mask tokens + mask_tokens = repeat(self.mask_token, '() () d -> b n d', b=B, + n=input_info['num_task_tokens'] - context_tokens_without_global.shape[1]) + context_with_mask = torch.cat([context_tokens_without_global, mask_tokens], dim=1) + + # Unshuffle context_with_mask + context_with_mask = torch.gather(context_with_mask, dim=1, + index=ids_restore.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2])) + + # Generate context_emb and add them to context + context_emb = self.generate_context_embeddings(input_info=input_info, bs=B, size=(N_H, N_W), + device=context_tokens.device) + context_with_mask = context_with_mask + context_emb + + # Generate queries + if self.use_task_queries and self.task in input_info['tasks']: + start_idx = input_info['tasks'][self.task]['start_idx'] + end_idx = input_info['tasks'][self.task]['end_idx'] + queries = context_with_mask[:, start_idx:end_idx] + else: + queries = repeat(self.mask_token, '() () d -> b n d', b=B, n=N_H * N_W) + queries_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bilinear', align_corners=False) + queries_pos_emb = rearrange(queries_pos_emb, 'b d nh nw -> b (nh nw) d') + queries = queries + queries_pos_emb + if self.task_embeddings is not None and self.task in self.task_embeddings: + queries_task_emb = repeat(self.task_embeddings[self.task], '() () d -> b n d', b=B, n=N_H * N_W) + queries = queries + queries_task_emb + + # Unshuffle context and keep only initial context (yes, again) + context_tokens_without_global = torch.gather(context_with_mask, dim=1, + index=ids_keep.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2])) + + # Add back global tokens + if 'num_global_tokens' in input_info: + context_tokens = torch.cat( + [context_tokens_without_global, context_tokens[:, -input_info['num_global_tokens']:]], dim=1) + else: + 
context_tokens = context_tokens_without_global + + return queries, context_tokens + + def forward(self, + encoder_tokens: torch.Tensor, + input_info: Dict, + ids_keep: torch.Tensor, + ids_restore: torch.Tensor, + ): + """ + Forward pass taking output tokens from encoder and optionally a subset of them corresponding + to this output adapter's task (needs an additional mask describing position of these tokens in the queries). + + :param encoder_tokens: Output of encoder + :param input_info: Dictionary with information about the input modalities + :param ids_keep: IDs of unmasked tokens (tokens given to the encoder) + :param ids_restore: IDs to unshuffle tokens + """ + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + H, W = input_info['image_size'] + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Project encoder tokens to decoder tokens + context_tokens = self.proj_context(encoder_tokens) + + # Get queries and context + queries, context_tokens = self.get_queries_and_context(context_tokens, input_info, ids_keep, ids_restore) + + # Perform cross attention of queries to context tokens, followed by an MLP + if self.use_xattn: + x = self.decoder(self.query_norm(queries), self.context_norm(context_tokens)) + x = x + self.mlp(self.out_norm(x)) + else: + x = queries + + # Optional transformer layers if depth > 0 + x = self.decoder_transformer(x) + + # Project each token to (C * P_H * P_W) + x = self.out_proj(x) + + # Reshape sequence of patches into image + x = rearrange( + x, 'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)', + nh=N_H, nw=N_W, ph=self.P_H, pw=self.P_W, c=self.num_channels + ) + + return x + + +class LinearOutputAdapter(nn.Module): + """ + Linear output adapter. + + :param num_classes: Number of classes + :param dim_tokens_enc: Dimension of tokens from the encoder + :param use_mean_pooling: When set to True, uses mean pooling before linear classification head. + Otherwise, use last token (usually the global token) + :param norm_layer: Normalization layer + :param init_scale: Initialization scale for linear classification head + """ + + def __init__(self, + num_classes: int, + dim_tokens_enc: Optional[int] = None, + use_mean_pooling: bool = True, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + init_scale: float = 1.0): + super().__init__() + self.num_classes = num_classes + self.dim_tokens_enc = dim_tokens_enc + self.use_mean_pooling = use_mean_pooling + self.norm_layer = norm_layer + self.init_scale = init_scale + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc: int = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + self.dim_tokens_enc = dim_tokens_enc + + self.norm = self.norm_layer(self.dim_tokens_enc) + self.head = nn.Linear(dim_tokens_enc, self.num_classes) if self.num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + self.head.weight.data.mul_(self.init_scale) + self.head.bias.data.mul_(self.init_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.init(dim_tokens_enc=self.dim_tokens_enc) + + def forward(self, + encoder_tokens: torch.Tensor, + **kwargs): + + if self.use_mean_pooling: + x = encoder_tokens.mean(1) + else: + # Global token is added at the end + x = encoder_tokens[:, -1] + + x = self.head(self.norm(x)) + return x + + +class SegmenterMaskTransformerAdapter(nn.Module): + """Output adapter inspired by the Segmenter-Mask architecture + + This head is the implementation of `Segmenter: `_. + + :param num_classes: Number of classes + :param depth: Depth of decoder + :param num_heads: Number of attention heads + :param embed_dim: Dimension of decoder tokens + :param mlp_ratio: MLP hidden dim ratio + :param drop_path_rate: DropPath drop rate + :param drop_rate: Dropout after MLPs and Attention + :param attn_drop_rate: Attention matrix drop rate + :param qkv_bias: Set to False to disable bias + :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept. + :param patch_size: Size of patches + :param norm_layer: Type of normalization layer + """ + + def __init__( + self, + num_classes, + depth: int = 2, + num_heads: int = 12, + embed_dim: int = 768, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + qkv_bias=True, + main_tasks: str = ('rgb',), + patch_size: int = 16, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + **kwargs, + ): + super().__init__() + self.main_tasks = main_tasks + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_classes = num_classes + + self.cls_emb = nn.Parameter(torch.zeros(1, num_classes, embed_dim)) + trunc_normal_(self.cls_emb, std=0.02) + + self.patch_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.classes_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + self.blocks = nn.ModuleList([ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth) + ]) + + self.decoder_norm = norm_layer(embed_dim) + self.mask_norm = norm_layer(num_classes) + self.apply(self._init_weights) + + def init(self, dim_tokens_enc: int = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + self.in_channels = dim_tokens_enc * len(self.main_tasks) + + # Projection of encoder tokens to the patch dimension + self.proj_dec = nn.Linear(self.in_channels, self.embed_dim) + self._init_weights(self.proj_dec) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def adapt_tokens(self, encoder_tokens, input_info): + # Adapt tokens + x = [] + for task in self.main_tasks: + start_idx = input_info['tasks'][task]['start_idx'] + end_idx = input_info['tasks'][task]['end_idx'] + x.append(encoder_tokens[:, start_idx:end_idx]) + + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: torch.Tensor, input_info: Dict): + H, W = input_info['image_size'] + N_H, N_W = H // self.patch_size, W // self.patch_size + + x = self.adapt_tokens(encoder_tokens, input_info) + + x = self.proj_dec(x) + cls_emb = self.cls_emb.expand(x.shape[0], -1, -1) + x = torch.cat((x, cls_emb), 1) + + for blk in self.blocks: + x = blk(x) + + x = self.decoder_norm(x) + + patches = self.patch_proj(x[:, :-self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + + patches = F.normalize(patches, dim=2, p=2) + cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) + + masks = patches @ cls_seg_feat.transpose(1, 2) + masks = self.mask_norm(masks) + masks = rearrange(masks, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) + + # Interpolate to semseg res + masks = F.interpolate(masks, size=(H, W), mode="bilinear") + + return masks + + +class ConvNeXtAdapter(nn.Module): + """Output adapter with ConvNext blocks for semantic segmentation + + :param num_classes: Number of classes + :param num_heads: Number of attention heads + :param embed_dim: Token dimension after projection, and before reshaping operation. + :param preds_per_patch: Increases size of feature map by reshaping each patch Each patch gets reshaped + from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5) + :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept. + :param patch_size: Size of patches + :param depth: Number of ConvNeXt blocks + :interpolate_mode: Interpolation mode for final upsampling + """ + + def __init__( + self, + num_classes, + embed_dim: int = 6144, + preds_per_patch: int = 16, + main_tasks: Iterable[str] = ('rgb',), + patch_size: int = 16, + depth: int = 4, + interpolate_mode: str = 'bilinear', + **kwargs, + ): + super().__init__() + self.main_tasks = main_tasks + self.patch_size = patch_size + self.embed_dim = embed_dim + self.preds_per_patch = preds_per_patch + self.class_dim = embed_dim // preds_per_patch + self.num_classes = num_classes + self.interpolate_mode = interpolate_mode + + self.blocks = nn.Sequential(*[ + ConvNeXtBlock(dim=self.class_dim) + for _ in range(depth) + ]) + self.final_layer = nn.Conv2d(self.class_dim, self.num_classes, 1) + self.apply(self._init_weights) + + def init(self, dim_tokens_enc: int = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + self.in_channels = dim_tokens_enc * len(self.main_tasks) + + # Projection of encoder tokens to the patch dimension + self.proj_dec = nn.Linear(self.in_channels, self.embed_dim) + self._init_weights(self.proj_dec) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def adapt_tokens(self, encoder_tokens, input_info): + # Adapt tokens + x = [] + for task in self.main_tasks: + start_idx = input_info['tasks'][task]['start_idx'] + end_idx = input_info['tasks'][task]['end_idx'] + x.append(encoder_tokens[:, start_idx:end_idx]) + + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: torch.Tensor, input_info: Dict): + H, W = input_info['image_size'] + N_H, N_W = H // self.patch_size, W // self.patch_size + + x = self.adapt_tokens(encoder_tokens, input_info) + + x = self.proj_dec(x) + x = rearrange(x, "b n (p c) -> b (n p) c", n=N_H * N_W, p=self.preds_per_patch, c=self.class_dim) + x = rearrange(x, "b (nh nw ph pw) c -> b c (nh ph) (nw pw)", + nh=N_H, nw=N_W, + ph=int(self.preds_per_patch ** 0.5), + pw=int(self.preds_per_patch ** 0.5)) + x = self.blocks(x) + x = self.final_layer(x) + + # Interpolate to semseg res + x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode) + + return x + + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_classes: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. 
+ :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__(self, + num_classes: int = 3, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ('rgb',), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = 'regression', + **kwargs): + super().__init__() + self.num_channels = num_classes + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn) + + if self.head_type == 'regression': + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(feature_dim // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, self.num_channels, kernel_size=1, stride=1, padding=0) + ) + elif self.head_type == 'semseg': + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc: int = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. 
+ + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) + + # Set up activation postprocessing layers + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc, + out_channels=self.layer_dims[0], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, stride=4, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc, + out_channels=self.layer_dims[1], + kernel_size=1, stride=1, padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, stride=2, padding=0, + bias=True, dilation=1, groups=1, + ) + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc, + out_channels=self.layer_dims[2], + kernel_size=1, stride=1, padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc, + out_channels=self.layer_dims[3], + kernel_size=1, stride=1, padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, stride=2, padding=1, + ) + ) + + self.act_postprocess = nn.ModuleList([ + self.act_1_postprocess, + self.act_2_postprocess, + self.act_3_postprocess, + self.act_4_postprocess + ]) + + def adapt_tokens(self, encoder_tokens, input_info): + # Adapt tokens + x = [] + for task in self.main_tasks: + start_idx = input_info['tasks'][task]['start_idx'] + end_idx = input_info['tasks'][task]['end_idx'] + x.append(encoder_tokens[:, start_idx:end_idx]) + + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], input_info: Dict): + assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' + H, W = input_info['image_size'] + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l, input_info) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] + + # Postprocess activations + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out \ No newline at end of file diff --git a/terratorch/models/backbones/multimae_register.py b/terratorch/models/backbones/multimae_register.py new file mode 100644 index 00000000..ff62b937 --- /dev/null +++ b/terratorch/models/backbones/multimae_register.py @@ -0,0 +1,389 @@ +""" +This module handles registering multimae models into timm. 
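+
+Once this module has been imported (so that the entrypoints below are registered),
+a model can be created through timm's factory. The following is an illustrative
+sketch only; the adapter names and keyword arguments are assumptions based on the
+entrypoints and DOMAIN_CONF defined below:
+
+    import timm
+
+    # "multimae_base" is registered below via @register_model; passing a list of
+    # modality names builds default input adapters from DOMAIN_CONF.
+    model = timm.create_model(
+        "multimae_base",
+        pretrained=False,
+        input_adapters=["S1", "S2L2A", "DEM"],
+    )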
+""" +import logging +import torch +import numpy as np +from pathlib import Path +from functools import partial +from timm.models import FeatureInfo +from timm.models._builder import build_model_with_cfg +from timm.models._registry import generate_default_cfgs, register_model + +from terratorch.datasets.utils import HLSBands, Modalities, S1Bands, DEMBands, LULCclasses +from terratorch.models.backbones.multimae.multimae import MultiMAE, MultiViT +from terratorch.models.backbones.multimae.criterion import MaskedMSELoss, MaskedCrossEntropyLoss +from terratorch.models.backbones.multimae.input_adapters import PatchedInputAdapter, SemSegInputAdapter +from terratorch.models.backbones.multimae.output_adapters import SpatialOutputAdapter, ConvNeXtAdapter + + +def _cfg(file: Path = "", **kwargs) -> dict: + return { + "file": file, + "source": "file", + "license": "mit", + **kwargs, + } + + +# TODO: Add pretrained models +# PRETRAINED_BANDS: list[HLSBands | int] = [ +# HLSBands.BLUE, +# HLSBands.GREEN, +# HLSBands.RED, +# ] + +# default_cfgs = generate_default_cfgs( +# { +# "multimae_base": _cfg( +# file="" +# ), +# "multimae_large": _cfg( +# file="" +# ) +# +# } +# ) + + +# TODO: make these user definable +DOMAIN_CONF = { + Modalities.S1: { + "channels": 2, + "stride_level": 1, + "input_adapter": partial(PatchedInputAdapter, num_channels=2), + "output_adapter": partial(SpatialOutputAdapter, num_channels=2), + "loss": MaskedMSELoss, + "image_size": 224, + "patch_size": 16, + }, + Modalities.S2L1C: { + "channels": 13, + "stride_level": 1, + "input_adapter": partial(PatchedInputAdapter, num_channels=13), + "output_adapter": partial(SpatialOutputAdapter, num_channels=13), + "loss": MaskedMSELoss, + "image_size": 224, + "patch_size": 16, + }, + Modalities.S2L2A: { + "channels": 12, + "stride_level": 1, + "input_adapter": partial(PatchedInputAdapter, num_channels=12), + "output_adapter": partial(SpatialOutputAdapter, num_channels=12), + "loss": MaskedMSELoss, + "image_size": 224, + "patch_size": 16, + }, + Modalities.S2RGB: { + "channels": 3, + "stride_level": 1, + "input_adapter": partial(PatchedInputAdapter, num_channels=3), + "output_adapter": partial(SpatialOutputAdapter, num_channels=3), + "loss": MaskedMSELoss, + "image_size": 224, + "patch_size": 16, + }, + Modalities.DEM: { + "channels": 1, + "stride_level": 1, + "input_adapter": partial(PatchedInputAdapter, num_channels=1), + "output_adapter": partial(SpatialOutputAdapter, num_channels=1), + "loss": MaskedMSELoss, + "image_size": 224, + "patch_size": 16, + }, + Modalities.LULC: { + "classes": 9, + "stride_level": 1, + "input_adapter": partial(SemSegInputAdapter, num_classes=9), + "output_adapter": partial(SpatialOutputAdapter, num_channels=9), # Used in pretraining + "loss": partial(MaskedCrossEntropyLoss, label_smoothing=0.0), + "image_size": 224, + "patch_size": 16, + }, + 'segmentation': { # TODO: Test generalized semseg head from MultiMAE! 
+ "classes": 9, + "stride_level": 1, + "input_adapter": partial(SemSegInputAdapter, num_classes=9), + "output_adapter": partial(SpatialOutputAdapter, num_channels=9), # Used in pretraining + "loss": partial(MaskedCrossEntropyLoss, label_smoothing=0.0), + "image_size": 224, + "patch_size": 16, + }, +} + + +def _instantiate_input_adapter_from_dict(spec: dict) -> PatchedInputAdapter | SemSegInputAdapter: + return spec["input_adapter"]( + stride_level=spec["stride_level"], + patch_size_full=spec["patch_size"], + image_size=spec["image_size"], + ) + + +def _parse_input_adapters( + adapter_spec: list | dict[str, str | dict[str, int | str]], +) -> dict[str, PatchedInputAdapter | SemSegInputAdapter]: + + if isinstance(adapter_spec, list): + # list to dict + adapter_spec = {m: m for m in adapter_spec} + if isinstance(adapter_spec, dict) and len(set(adapter_spec.keys())) != len(adapter_spec.keys()): + msg = "Duplicate keys in input adapters" + raise Exception(msg) + input_adapters = {} + + for adapter_name, spec in adapter_spec.items(): + match spec: + case str(spec): + try: + spec = Modalities(spec.upper()) + except ValueError: + pass + + if spec in DOMAIN_CONF.keys(): + input_adapters[adapter_name] = _instantiate_input_adapter_from_dict(DOMAIN_CONF[spec]) + else: + msg = f"Input Domain {adapter_name} does not exist. Choose one of {list(DOMAIN_CONF.keys())}" + raise ValueError(msg) + case {"type": "PatchedInputAdapter", "num_channels": num_channels, **kwargs}: + input_adapters[adapter_name] = PatchedInputAdapter(num_channels=num_channels, **kwargs) + case {"type": "SemSegInputAdapter", "num_classes": num_classes, **kwargs}: + input_adapters[adapter_name] = SemSegInputAdapter(num_classes=num_classes, **kwargs) + case _: + msg = f"Invalid input adapter config for adapter {adapter_name}" + raise ValueError(msg) + return input_adapters + + +def _instantiate_output_adapter_from_dict(spec: dict, task: str, context_tasks: list[str], # num_channels: int + ) -> SpatialOutputAdapter | ConvNeXtAdapter: + return spec["output_adapter"]( + stride_level=spec["stride_level"], + patch_size_full=spec["patch_size"], + image_size=spec["image_size"], + task=task, + context_tasks=context_tasks, + # num_channels=spec['channels'], # TODO: Not passed in pretraining code + ) + + +def _parse_output_adapters( + adapter_spec: list | dict[str, str | dict[str, int | str]], +) -> dict[str, SpatialOutputAdapter | SpatialOutputAdapter]: + + if isinstance(adapter_spec, list): + # list to dict + adapter_spec = {m: m for m in adapter_spec} + if isinstance(adapter_spec, dict) and len(set(adapter_spec.keys())) != len(adapter_spec.keys()): + msg = "Duplicate keys in output adapters" + raise Exception(msg) + output_adapters = {} + + for adapter_name, spec in adapter_spec.items(): + match spec: + case str(spec): + try: + spec = Modalities(spec.upper()) + except ValueError: + pass + + if spec in DOMAIN_CONF.keys(): + output_adapters[adapter_name] = _instantiate_output_adapter_from_dict( + DOMAIN_CONF[spec], + task=adapter_name, + context_tasks=list(adapter_spec.keys()), + ) + else: + msg = f"output Domain {adapter_name} does not exist. 
Choose one of {list(DOMAIN_CONF.keys())}"
+                raise ValueError(msg)
+            case {"type": "SpatialOutputAdapter", "num_channels": num_channels, **kwargs}:  # Used for pre-training
+                output_adapters[adapter_name] = SpatialOutputAdapter(
+                    task=adapter_name,
+                    context_tasks=list(adapter_spec.keys()),
+                    num_channels=num_channels,
+                    **kwargs
+                )
+            case {"type": "ConvNeXtAdapter", "num_classes": num_classes, **kwargs}:
+                output_adapters[adapter_name] = ConvNeXtAdapter(num_classes=num_classes, **kwargs)
+            case _:
+                msg = f"Invalid output adapter config for adapter {adapter_name}"
+                raise ValueError(msg)
+    return output_adapters
+
+
+# If you need to adapt the checkpoint file, do it here
+def checkpoint_filter_fn(
+    modalities: list[str],
+    state_dict: dict[str, torch.Tensor],
+    model: torch.nn.Module,
+):
+    new_state_dict = {}
+
+    for k, v in state_dict.items():
+        if "output_adapters" in k:
+            continue
+
+        # drop pos emb
+        if "pos_emb" in k:
+            continue
+
+        if k.startswith("input_adapters."):
+            try:
+                modality_name = k.split(".")[1]
+                modality = Modalities(modality_name)
+            except ValueError:
+                print(f"Modality {modality_name} is not in allowed modalities. Skipping {k}.")
+                continue
+            if modality.value not in modalities:
+                print(f"Removing input adapter for {modality_name}: {k}")
+                continue
+        if k.startswith("encoder."):
+            new_k = "layers." + k.removeprefix("encoder.")
+            new_state_dict[new_k] = v
+        else:
+            new_state_dict[k] = v
+    return new_state_dict
+
+
+class PrepareMultimodalFeaturesForDecoder:
+    def __init__(self, modalities: list[str], merging_method: str = 'concat'):
+        self.modalities = modalities
+        self.merging_method = merging_method
+
+    def __call__(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
+        if len(x) == 2:
+            # MultiMAE decoder was used. Return predictions of the single output modality.
+            preds = list(x[0].values())
+            assert len(preds) == 1, "Terratorch can only handle one output modality."
+            return preds[0]
+
+        for output_index in range(len(x)):
+            x[output_index] = x[output_index].permute(0, 2, 1)
+            x[output_index] = x[output_index][:, :, :-1]  # remove global token
+            img_shape = int(np.sqrt(x[output_index].shape[-1] / len(self.modalities)))
+            if self.merging_method == 'concat':
+                x[output_index] = x[output_index].reshape(x[output_index].shape[0], -1, img_shape, img_shape)
+            else:
+                raise ValueError(f"Unsupported merging method {self.merging_method}")
+                # TODO: Implement other methods, move to forward?
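+        # Illustrative shapes for the 'concat' path above (assumed example, not taken
+        # from the original code): with 2 modalities of 14x14 patches each, one global
+        # token and embed dim 768, each element goes (B, 393, 768) -> permute ->
+        # (B, 768, 393) -> drop global token -> (B, 768, 392) -> reshape -> (B, 1536, 14, 14).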
+ + return x + + +def _create_multimae( + variant: str, + modalities: list[str], + pretrained: bool = False, + features_only: bool = True, + merging_method: str = None, + **kwargs, +): + model: torch.nn.Module = build_model_with_cfg( + MultiViT if features_only else MultiMAE, # MultiViT is an encoder-only model + variant, + pretrained, + # if you need to adapt the checkpoint file + pretrained_filter_fn=partial(checkpoint_filter_fn, modalities), + pretrained_strict=False, + feature_cfg={ + "flatten_sequential": True, + }, + features_only=False, + merging_method=merging_method, + **kwargs, + ) + default_out_indices = list(range(len(model.layers))) + out_indices = kwargs.get("out_indices", default_out_indices) + model.feature_info = FeatureInfo(model.feature_info, out_indices) + + model.prepare_features_for_image_model = PrepareMultimodalFeaturesForDecoder( + modalities, + merging_method=merging_method + ) + return model + + +@register_model +def multimae_base( + input_adapters: dict[str, str | dict[str, int | str]] | None = None, + output_adapters: dict[str, str | dict[str, int | str]] | None = None, + pretrained: bool = False, + features_only: bool = True, + **kwargs, +) -> torch.nn.Module: + """MultiMAE base model.""" + + if input_adapters is None: + input_adapters = ['S1', 'S2L1C', 'S2L2A', 'DEM', 'LULC'] + logging.warning(f'') + input_adapters = _parse_input_adapters(input_adapters) + + if output_adapters is not None: + output_adapters = _parse_output_adapters(output_adapters) + + model_args = { + "input_adapters": input_adapters, + "output_adapters": output_adapters, + "dim_tokens": 768, + "depth": 12, + "num_heads": 12, + "mlp_ratio": 4, + "qkv_bias": True, + "norm_layer": partial(torch.nn.LayerNorm, eps=1e-6), + } + + kwargs.pop('features_only', None) + merging_method = None if output_adapters else kwargs.get('merging_method', 'concat') + + transformer = _create_multimae( + "multimae_base", + list(input_adapters.keys()), + pretrained=pretrained, + features_only=output_adapters is None, + merging_method=merging_method, + **dict(model_args, **kwargs), + ) + return transformer + + +@register_model +def multimae_large( + input_adapters: dict[str, str | dict[str, int | str]] | None = None, + output_adapters: dict[str, str | dict[str, int | str]] | None = None, + pretrained: bool = False, # noqa: FBT002, FBT001 + **kwargs, +) -> torch.nn.Module: + """MultiMAE large model.""" + + if input_adapters is None: + input_adapters = ['S1', 'S2L1C', 'S2L2A', 'DEM', 'LULC'] + input_adapters = _parse_input_adapters(input_adapters) + + if output_adapters is not None: + output_adapters = _parse_output_adapters(output_adapters) + + model_args = { + "input_adapters": input_adapters, + "output_adapters": output_adapters, + "dim_tokens": 1024, + "depth": 24, + "num_heads": 16, + "mlp_ratio": 4, + "qkv_bias": True, + "norm_layer": partial(torch.nn.LayerNorm, eps=1e-6), + } + + kwargs.pop('features_only', None) + merging_method = None if output_adapters else kwargs.get('merging_method', 'concat') + + transformer = _create_multimae( + "multimae_large", + list(input_adapters.keys()), + pretrained=pretrained, + features_only=output_adapters is None, + merging_method=merging_method, + **dict(model_args, **kwargs), + ) + return transformer From cfd291c093a89afff5ed25c30c3dcaea424be661 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 5 Nov 2024 09:59:08 +0100 Subject: [PATCH 20/42] Updated multimodal dataset Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 7 +++++- 
.../datasets/generic_multimodal_dataset.py | 9 ++------ terratorch/datasets/utils.py | 23 +++++++++++++++++++ terratorch/tasks/segmentation_tasks.py | 11 +++++---- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 67aa0463..b373b2b1 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -66,8 +66,13 @@ def __call__(self, batch): # B, (T,) H, W means = torch.tensor(self.means[m], device=image.device) stds = torch.tensor(self.stds[m], device=image.device) + elif len(image.shape) == 3: # No batch dim + # C, H, W + means = torch.tensor(self.means[m], device=image.device).view(-1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(-1, 1, 1) else: - msg = f"Expected batch to have 5 or 4 dimensions or a single channel, but got {len(image.shape)}" + msg = (f"Expected batch with 5 or 4 dimensions (B, C, (T,) H, W), sample with 3 dimensions (C, H, W) " + f"or a single channel, but got {len(image.shape)}") raise Exception(msg) batch['image'][m] = (image - means) / stds return batch diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index a408cffc..390e2ad0 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -405,12 +405,7 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> .. versionadded:: 0.2 """ - raise NotImplementedError('Code is based on the generic single-modality dataset and not yet adapted. ' - 'Set `export TERRATORCH_NUM_VAL_PLOTS=0` before running terratorch.') - - image = sample[self.rgb_modality] - if len(image.shape) == 5: # TODO: Fix plot code. 
- return + image = sample["image"] if isinstance(image, torch.Tensor): image = image.numpy() image = image.take(self.rgb_indices, axis=0) @@ -493,7 +488,7 @@ def __init__( allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - constant_scale: float = 1, + constant_scale: float = 1., transform: A.Compose | None = None, no_data_replace: float | None = None, no_label_replace: int | None = None, diff --git a/terratorch/datasets/utils.py b/terratorch/datasets/utils.py index 8fa26cd4..04b9a534 100644 --- a/terratorch/datasets/utils.py +++ b/terratorch/datasets/utils.py @@ -34,6 +34,29 @@ def try_convert_to_hls_bands_enum(cls, x: Any): except ValueError: return x + +class S1Bands(Enum): + VV = 'VV' + VH = 'VH' + + +class DEMBands(Enum): + DEM = 'DEM' + + +class LULCclasses(Enum): + LULC = 'LULC' + + +class Modalities(Enum): + S1 = "S1" + S2L1C = "S2L1C" + S2L2A = "S2L2A" + S2RGB = "S2RGB" + DEM = "DEM" + LULC = "LULC" + + def default_transform(**batch): return to_tensor(batch) diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index a0afa690..1404077e 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -18,7 +18,7 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.tiled_inference import TiledInferenceParameters, tiled_inference -BATCH_IDX_FOR_VALIDATION_PLOTTING = os.getenv('TERRATORCH_NUM_VAL_PLOTS', 10) +BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 def to_segmentation_prediction(y: ModelOutput) -> Tensor: @@ -226,7 +226,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> model_output: ModelOutput = self(x) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + self.train_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) self.train_metrics.update(y_hat_hard, y) @@ -262,7 +262,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - y = batch["mask"] model_output: ModelOutput = self(x) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) self.val_metrics.update(y_hat_hard, y) @@ -270,6 +270,9 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - try: datamodule = self.trainer.datamodule batch["prediction"] = y_hat_hard + if isinstance(batch["image"], dict): + # Multimodal input + batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality] for key in ["image", "mask", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] @@ -305,7 +308,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None model_output: ModelOutput = self(x) loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) 
self.test_metrics.update(y_hat_hard, y) From d792a56c5bcfc2bd6993310fa3eaa01accf12688 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 10:03:49 +0100 Subject: [PATCH 21/42] Added sequence data to multimodal dataset Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 37 +++++++++++++------ .../datasets/generic_multimodal_dataset.py | 34 +++++++++++------ terratorch/datasets/transforms.py | 18 +++++++-- 3 files changed, 62 insertions(+), 27 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index b373b2b1..1a1d8d2a 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -35,10 +35,16 @@ def collate_chunk_dicts(batch_list): return batch -def wrap_in_compose_is_list(transform_list, additional_targets=None): +def wrap_in_compose_is_list(transform_list, image_modalities=None, sequence_modalities=None): + additional_targets = {} + if image_modalities: + for modality in image_modalities: + additional_targets[modality] = 'image' + if sequence_modalities: + # Global label values are ignored and need to be processed separately + for modality in sequence_modalities: + additional_targets[modality] = "global_label" # set check shapes to false because of the multitemporal case - if additional_targets: - additional_targets = {m: 'image' for m in additional_targets} return A.Compose(transform_list, is_check_shapes=False, additional_targets=additional_targets) \ if isinstance(transform_list, Iterable) else transform_list @@ -148,6 +154,7 @@ def __init__( output_bands: dict | None = None, predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_substring_split_file: bool = False, @@ -237,6 +244,8 @@ def __init__( super().__init__(dataset_class, batch_size, num_workers, **kwargs) self.num_classes = num_classes self.modalities = modalities + self.image_modalities = image_modalities or modalities + self.sequence_modalities = list(set(self.modalities) - set(image_modalities)) if isinstance(img_grep, dict): self.img_grep = {m: img_grep[m] if m in img_grep else '*' for m in modalities} else: @@ -279,17 +288,13 @@ def __init__( self.reduce_zero_label = reduce_zero_label self.channel_position = channel_position - # Transforms can be None (leads to to_tensor default), shared between modalities or individual per modality - if shared_transforms: - # Applying the same transforms with the same parameters to multiple images - shared_transforms = shared_transforms if isinstance(shared_transforms, list) else modalities - assert shared_transforms == modalities, "Non-image modalities not yet supported with shared_transforms" - if isinstance(train_transform, dict): self.train_transform = {m: wrap_in_compose_is_list(train_transform[m]) if m in train_transform else None for m in modalities} elif shared_transforms: - self.train_transform = wrap_in_compose_is_list(train_transform, additional_targets=shared_transforms) + self.train_transform = wrap_in_compose_is_list(train_transform, + image_modalities=self.image_modalities, + sequence_modalities=self.sequence_modalities) else: self.train_transform = {m: wrap_in_compose_is_list(train_transform) for m in modalities} @@ -298,7 +303,9 
@@ def __init__( self.val_transform = {m: wrap_in_compose_is_list(val_transform[m]) if m in val_transform else None for m in modalities} elif shared_transforms: - self.val_transform = wrap_in_compose_is_list(val_transform, additional_targets=shared_transforms) + self.val_transform = wrap_in_compose_is_list(val_transform, + image_modalities=self.image_modalities, + sequence_modalities=self.sequence_modalities) else: self.val_transform = {m: wrap_in_compose_is_list(val_transform) for m in modalities} @@ -307,7 +314,9 @@ def __init__( self.test_transform = {m: wrap_in_compose_is_list(test_transform[m]) if m in test_transform else None for m in modalities} elif shared_transforms: - self.test_transform = wrap_in_compose_is_list(test_transform, additional_targets=shared_transforms) + self.test_transform = wrap_in_compose_is_list(test_transform, + image_modalities=self.image_modalities, + sequence_modalities=self.sequence_modalities) else: self.test_transform = {m: wrap_in_compose_is_list(test_transform) for m in modalities} @@ -334,6 +343,7 @@ def setup(self, stage: str) -> None: dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, + image_modalities=self.image_modalities, rgb_modality=self.rgb_modality, rgb_indices=self.rgb_indices, transform=self.train_transform, @@ -356,6 +366,7 @@ def setup(self, stage: str) -> None: dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, + image_modalities=self.image_modalities, rgb_modality=self.rgb_modality, rgb_indices=self.rgb_indices, transform=self.val_transform, @@ -378,6 +389,7 @@ def setup(self, stage: str) -> None: dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, + image_modalities=self.image_modalities, rgb_modality=self.rgb_modality, rgb_indices=self.rgb_indices, transform=self.test_transform, @@ -396,6 +408,7 @@ def setup(self, stage: str) -> None: dataset_bands=self.predict_dataset_bands, output_bands=self.predict_output_bands, constant_scale=self.constant_scale, + image_modalities=self.image_modalities, rgb_modality=self.rgb_modality, rgb_indices=self.rgb_indices, transform=self.test_transform, diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 390e2ad0..5564a45d 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -3,7 +3,6 @@ """Module containing generic dataset classes""" import glob -import logging import os import torch from abc import ABC @@ -59,6 +58,7 @@ def __init__( image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, + image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet (collate_fn required). @@ -108,6 +108,8 @@ def __init__( self.modalities = list(data_root.keys()) assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." 
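For context, the shared-transform path above relies on albumentations' additional_targets so that one Compose applies the same spatial transform, with identical random parameters, to every image modality, while non-image (sequence) modalities are routed around albumentations. A small self-contained sketch of the additional_targets mechanism, assuming two hypothetical modalities stored as channels-last numpy arrays:

    import numpy as np
    import albumentations as A

    transform = A.Compose(
        [A.RandomCrop(height=224, width=224), A.HorizontalFlip(p=0.5)],
        additional_targets={"S1": "image"},  # extra keys transformed like "image"
        is_check_shapes=False,
    )

    s2 = np.random.rand(256, 256, 6).astype(np.float32)  # hypothetical S2 patch, HWC
    s1 = np.random.rand(256, 256, 2).astype(np.float32)  # hypothetical S1 patch, HWC

    out = transform(image=s2, S1=s1)  # the same crop and flip are applied to both arrays
    print(out["image"].shape, out["S1"].shape)  # (224, 224, 6) (224, 224, 2)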
+ self.image_modalities = image_modalities or self.modalities + self.sequence_modalities = list(set(self.modalities) - set(image_modalities)) # Convert path strings to lists for m, m_dir in data_root.items(): @@ -229,7 +231,7 @@ def __init__( # If no transform is given, apply only to transform to torch tensor if isinstance(transform, A.Compose): - self.transform = MultimodalTransforms(transform) + self.transform = MultimodalTransforms(transform, sequence_modalities=self.sequence_modalities) elif transform is None: self.transform = MultimodalToTensor(self.modalities) else: @@ -238,6 +240,7 @@ def __init__( for m in self.modalities} self.transform = MultimodalTransforms(transform, shared=False) + # Ignore rasterio of not geo-referenced files import warnings import rasterio warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning) @@ -257,7 +260,10 @@ def __getitem__(self, index: int) -> dict[str, Any]: for modality, file in sample.items(): data = self._load_file( - file, nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace).to_numpy() + file, + nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace, + modality=modality, + ).to_numpy() # Expand temporal dim if modality in self.filter_indices and self.expand_temporal_dimension: @@ -267,7 +273,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if modality == 'mask': data = data[0] - if len(data.shape) >= 3 and self.channel_position: + if modality in self.image_modalities and len(data.shape) >= 3 and self.channel_position: # to channels last (required by albumentations) data = np.moveaxis(data, self.channel_position, -1) @@ -293,10 +299,10 @@ def __getitem__(self, index: int) -> dict[str, Any]: return output - def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: + def _load_file(self, path, nan_replace: int | float | None = None, modality: str | None = None) -> xr.DataArray: if path.endswith('.zarr') or path.endswith('.zarr.zip'): data = xr.open_zarr(path, mask_and_scale=True) - data_var = list(data.data_vars)[0] # TODO: Make data var configurable if required (e.g. 
for time/loc) + data_var = modality if modality in data.data_vars else list(data.data_vars)[0] data = data[data_var] elif path.endswith('.npy'): data = xr.DataArray(np.load(path)) @@ -304,7 +310,11 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr data = rioxarray.open_rasterio(path, masked=True) if nan_replace is not None: - data = data.fillna(nan_replace) + try: + data = data.fillna(nan_replace) + except np.exceptions.DTypePromotionError as e: + # No common dtype, e.g., for timestamps + pass return data @@ -319,6 +329,7 @@ def __init__( image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, + image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[str] | None = None, allow_missing_modalities: bool = False, @@ -371,6 +382,7 @@ def __init__( image_grep=image_grep, label_grep=label_grep, split=split, + image_modalities=image_modalities, rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, @@ -482,13 +494,14 @@ def __init__( image_grep: str | None = "*", label_grep: str | None = "*", split: Path | None = None, + image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities : bool = False, allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - constant_scale: float = 1., + constant_scale: float = 1., # TODO: Check types of args transform: A.Compose | None = None, no_data_replace: float | None = None, no_label_replace: int | None = None, @@ -531,6 +544,7 @@ def __init__( image_grep=image_grep, label_grep=label_grep, split=split, + image_modalities=image_modalities, rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, @@ -563,12 +577,8 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> .. versionadded:: 0.2 """ - raise NotImplementedError('Code is based on the generic single-modality dataset and not yet adapted. 
' - 'Set `export TERRATORCH_NUM_VAL_PLOTS=0` before running terratorch.') image = sample["image"] - if len(image.shape) == 5: - return if isinstance(image, torch.Tensor): image = image.numpy() image = image.take(self.rgb_indices, axis=0) diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index 6f28b6b8..41b9f39c 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -1,8 +1,8 @@ # Copyright contributors to the Terratorch project +import torch from albumentations import BasicTransform, Compose, ImageOnlyTransform from einops import rearrange -from torch import Tensor import albumentations as A N_DIMS_FOR_TEMPORAL = 4 @@ -160,17 +160,29 @@ def get_transform_init_args_names(self): class MultimodalTransforms: """Applies albumentations transforms to multiple images""" - def __init__(self, transforms: dict | A.Compose, shared : bool = True): + def __init__( + self, + transforms: dict | A.Compose, + shared : bool = True, + sequence_modalities: list[str] | None = None, + sequence_transform: object | None = None, + ): self.transforms = transforms self.shared = shared + self.sequence_modalities = sequence_modalities + self.sequence_transform = sequence_transform or torch.from_numpy def __call__(self, data: dict): if self.shared: - # albumentations requires a key 'image' + # albumentations requires a key 'image' and treats all other keys as additional targets image_modality = list(data.keys())[0] data['image'] = data.pop(image_modality) data = self.transforms(**data) data[image_modality] = data.pop('image') + + # Process sequence data which is ignored by albumentations as 'global_label' + for modality in self.sequence_modalities: + data[modality] = self.sequence_transform(data[modality]) else: # Applies transformations for each modality separate for key, value in data.items(): From 707bcf5054a0cea67933f93549a46c5ea21b78d3 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 11:10:14 +0100 Subject: [PATCH 22/42] Add modality files to multimodal dataset Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 11 +-- .../datasets/generic_multimodal_dataset.py | 88 +++++++++++++++---- terratorch/datasets/transforms.py | 11 ++- 3 files changed, 84 insertions(+), 26 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 1a1d8d2a..4f8584a5 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -28,6 +28,8 @@ def collate_chunk_dicts(batch_list): for key, value in batch_list[0].items(): # TODO: Handle missing modalities when allow_missing_modalities is set. 
if isinstance(value, torch.Tensor): batch[key] = torch.concat([chunk[key] for chunk in batch_list]) + if isinstance(value, np.ndarray): + batch[key] = np.concatenate([chunk[key] for chunk in batch_list]) elif isinstance(value, dict): batch[key] = collate_chunk_dicts([chunk[key] for chunk in batch_list]) else: @@ -158,7 +160,7 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_substring_split_file: bool = False, - constant_scale: dict | float = 1., + constant_scale: dict[float] = None, train_transform: dict | A.Compose | None | list[A.BasicTransform] = None, val_transform: dict | A.Compose | None | list[A.BasicTransform] = None, test_transform: dict | A.Compose | None | list[A.BasicTransform] = None, @@ -263,13 +265,6 @@ def __init__( self.test_split = test_split self.allow_substring_split_file = allow_substring_split_file self.constant_scale = constant_scale - if isinstance(self.constant_scale, dict): - # Fill in missing modalities - self.constant_scale = {m: self.constant_scale[m] if m in self.constant_scale else 1. - for m in modalities} - else: - # Create dict - self.constant_scale = {m: constant_scale for m in modalities} self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace self.drop_last = drop_last diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 5564a45d..9691677d 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -5,6 +5,7 @@ import glob import os import torch +import pandas as pd from abc import ABC from pathlib import Path from typing import Any @@ -24,6 +25,16 @@ from terratorch.datasets.transforms import MultimodalTransforms +def load_table_data(file_path: str): + if file_path.endswith('parquet'): + df = pd.read_parquet(file_path) + elif file_path.endswith('csv'): + df = pd.read_csv(file_path, index_col=0) + else: + raise Exception(f"Unrecognized file type: {file_path}. Only parquet and csv are supported.") + return df + + class MultimodalToTensor(): def __init__(self, modalities): self.modalities = modalities @@ -66,7 +77,7 @@ def __init__( dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: dict[float] = None, - transform: A.Compose | None = None, + transform: A.Compose | dict | None = None, no_data_replace: float | None = None, no_label_replace: int | None = None, expand_temporal_dimension: bool = False, @@ -118,7 +129,7 @@ def __init__( if label_data_root and not isinstance(label_data_root, list): label_data_root = [label_data_root] - self.constant_scale = {m: constant_scale[m] or 1. 
for m in self.modalities} + self.constant_scale = constant_scale or [] self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace self.reduce_zero_label = reduce_zero_label @@ -131,9 +142,12 @@ def __init__( # Load samples based on split file if self.split_file is not None: - with open(self.split_file) as f: - split = f.readlines() - valid_files = {rf"{substring.strip()}" for substring in split} + if str(self.split_file).endswith('.txt'): + with open(self.split_file) as f: + split = f.readlines() + valid_files = {rf"{substring.strip()}" for substring in split} + else: + valid_files = list(load_table_data(str(self.split_file)).index) else: image_files = {} @@ -159,11 +173,48 @@ def __init__( self.samples = [] num_modalities = len(self.modalities) + int(label_data_root is not None) + + # Check for parquet and csv files with modality data and read the file + for m, m_dirs in data_root.items(): + m_dfs = [] + for m_dir in m_dirs: + if os.path.isfile(m_dir): + m_dfs.append(load_table_data(m_dir)) + if len(m_dfs): + # Replace paths with DataFrame + data_root[m] = pd.concat(m_dfs, axis=0) + + # Check for sample key + assert list(valid_files)[0] in data_root[m].index, \ + (f"Sample key expected in table index (first column) for {m}, " + f"key '{list(valid_files)[0]}' is not in index [{list(data_root[m].index[:3])}, ...]") + + if label_data_root: + # Check for parquet and csv files with labels and read the file + l_dfs = [] + for l_dir in label_data_root: + if os.path.isfile(l_dir): + l_dfs.append(load_table_data(l_dir)) + if len(l_dfs): + # Replace paths with DataFrame + label_data_root = pd.concat(l_dfs, axis=0) + + # Check for sample key + assert list(valid_files)[0] in label_data_root.index, \ + (f"Sample key expected in table index (first column) of the labels, " + f"key '{list(valid_files)[0]}' is not in label index [{list(label_data_root.index[:3])}, ...]") + # Iterate over all files in split for file in valid_files: sample = {} # Iterate over all modalities for m, m_dirs in data_root.items(): + # Add tabular data to sample + if isinstance(m_dirs, pd.DataFrame): + # m_dirs paths is replaced by DataFrame + sample[m] = m_dirs.loc[file].values + continue + # Iterate over all directories of the current modality for m_dir in m_dirs: if allow_substring_split_file: @@ -179,6 +230,12 @@ def __init__( sample[m] = file_path break if label_data_root: + # Add tabular data to sample + if isinstance(label_data_root, pd.DataFrame): + # m_dirs paths is replaced by DataFrame + sample['mask'] = label_data_root.loc[file].values + continue + for l_dir in label_data_root: if allow_substring_split_file: # Substring match with label_grep @@ -263,7 +320,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: file, nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace, modality=modality, - ).to_numpy() + ) # Expand temporal dim if modality in self.filter_indices and self.expand_temporal_dimension: @@ -280,7 +337,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if modality in self.filter_indices: data = data[..., self.filter_indices[modality]] - if modality != 'mask': + if modality in self.constant_scale: data = data.astype(np.float32) * self.constant_scale[modality] output[modality] = data @@ -300,21 +357,20 @@ def __getitem__(self, index: int) -> dict[str, Any]: return output def _load_file(self, path, nan_replace: int | float | None = None, modality: str | None = None) -> xr.DataArray: - if path.endswith('.zarr') or path.endswith('.zarr.zip'): + if 
isinstance(path, np.ndarray): + # data was loaded from table and is saved in memory + data = path + elif path.endswith('.zarr') or path.endswith('.zarr.zip'): data = xr.open_zarr(path, mask_and_scale=True) data_var = modality if modality in data.data_vars else list(data.data_vars)[0] - data = data[data_var] + data = data[data_var].to_numpy() elif path.endswith('.npy'): - data = xr.DataArray(np.load(path)) + data = np.load(path) else: - data = rioxarray.open_rasterio(path, masked=True) + data = rioxarray.open_rasterio(path, masked=True).to_numpy() if nan_replace is not None: - try: - data = data.fillna(nan_replace) - except np.exceptions.DTypePromotionError as e: - # No common dtype, e.g., for timestamps - pass + data = np.nan_to_num(data, nan=nan_replace) return data diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index 41b9f39c..f175e3a6 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -158,6 +158,13 @@ def get_transform_init_args_names(self): return "band_indices" +def default_sequence_transform(array): + if array.dtype == float or array.dtype == int: + return torch.from_numpy(array) + else: + return array + + class MultimodalTransforms: """Applies albumentations transforms to multiple images""" def __init__( @@ -170,12 +177,12 @@ def __init__( self.transforms = transforms self.shared = shared self.sequence_modalities = sequence_modalities - self.sequence_transform = sequence_transform or torch.from_numpy + self.sequence_transform = sequence_transform or default_sequence_transform def __call__(self, data: dict): if self.shared: # albumentations requires a key 'image' and treats all other keys as additional targets - image_modality = list(data.keys())[0] + image_modality = list(set(data.keys()) - set(self.sequence_modalities))[0] data['image'] = data.pop(image_modality) data = self.transforms(**data) data[image_modality] = data.pop('image') From 421306c95ee8b886b0790279c0e7210442e67b10 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 11:21:17 +0100 Subject: [PATCH 23/42] rename m_dirs Signed-off-by: Benedikt Blumenstiel --- .../datasets/generic_multimodal_dataset.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 9691677d..634ca337 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -122,10 +122,10 @@ def __init__( self.image_modalities = image_modalities or self.modalities self.sequence_modalities = list(set(self.modalities) - set(image_modalities)) - # Convert path strings to lists - for m, m_dir in data_root.items(): - if not isinstance(m_dir, list): - data_root[m] = [m_dir] + # Convert path strings to lists as the code expects a list of paths per modality + for m, m_path in data_root.items(): + if not isinstance(m_path, list): + data_root[m] = [m_path] if label_data_root and not isinstance(label_data_root, list): label_data_root = [label_data_root] @@ -151,8 +151,8 @@ def __init__( else: image_files = {} - for m, m_dirs in data_root.items(): - dir_lists = [glob.glob(os.path.join(r, image_grep[m])) for r in m_dirs] + for m, m_paths in data_root.items(): + dir_lists = [glob.glob(os.path.join(r, image_grep[m])) for r in m_paths] image_files[m] = sorted([p for l in dir_lists for p in l]) # Concatenate if label_data_root: @@ -175,11 +175,11 @@ def __init__( num_modalities = 
len(self.modalities) + int(label_data_root is not None) # Check for parquet and csv files with modality data and read the file - for m, m_dirs in data_root.items(): + for m, m_paths in data_root.items(): m_dfs = [] - for m_dir in m_dirs: - if os.path.isfile(m_dir): - m_dfs.append(load_table_data(m_dir)) + for m_path in m_paths: + if os.path.isfile(m_path): + m_dfs.append(load_table_data(m_path)) if len(m_dfs): # Replace paths with DataFrame data_root[m] = pd.concat(m_dfs, axis=0) @@ -192,9 +192,9 @@ def __init__( if label_data_root: # Check for parquet and csv files with labels and read the file l_dfs = [] - for l_dir in label_data_root: - if os.path.isfile(l_dir): - l_dfs.append(load_table_data(l_dir)) + for l_path in label_data_root: + if os.path.isfile(l_path): + l_dfs.append(load_table_data(l_path)) if len(l_dfs): # Replace paths with DataFrame label_data_root = pd.concat(l_dfs, axis=0) @@ -208,31 +208,31 @@ def __init__( for file in valid_files: sample = {} # Iterate over all modalities - for m, m_dirs in data_root.items(): + for m, m_paths in data_root.items(): # Add tabular data to sample - if isinstance(m_dirs, pd.DataFrame): - # m_dirs paths is replaced by DataFrame - sample[m] = m_dirs.loc[file].values + if isinstance(m_paths, pd.DataFrame): + # m_paths was replaced by DataFrame + sample[m] = m_paths.loc[file].values continue # Iterate over all directories of the current modality - for m_dir in m_dirs: + for m_path in m_paths: if allow_substring_split_file: # Substring match with image_grep - m_files = glob.glob(os.path.join(m_dir, file + image_grep[m])) + m_files = glob.glob(os.path.join(m_path, file + image_grep[m])) if m_files: sample[m] = m_files[0] break else: # Exact match - file_path = os.path.join(m_dir, file) + file_path = os.path.join(m_path, file) if os.path.isfile(file_path): sample[m] = file_path break if label_data_root: # Add tabular data to sample if isinstance(label_data_root, pd.DataFrame): - # m_dirs paths is replaced by DataFrame + # label_data_root was replaced by DataFrame sample['mask'] = label_data_root.loc[file].values continue From 3ebec8e2fed4768a18e87a46c66673c4f67b4209 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 17:19:51 +0100 Subject: [PATCH 24/42] Added multimodal scalar dataset and concat_bands Signed-off-by: Benedikt Blumenstiel --- examples/confs/multimae_sen1floods11.yaml | 16 +- .../multimodal_prithvi_sen1floods11.yaml | 170 ++++++++++++ examples/confs/multimodal_sen1floods11.yaml | 183 ------------- .../generic_multimodal_data_module.py | 74 +++-- terratorch/datasets/__init__.py | 4 +- .../datasets/generic_multimodal_dataset.py | 254 +++++++++++++++--- 6 files changed, 457 insertions(+), 244 deletions(-) create mode 100644 examples/confs/multimodal_prithvi_sen1floods11.yaml delete mode 100644 examples/confs/multimodal_sen1floods11.yaml diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml index 67ee4174..73d5bf6b 100644 --- a/examples/confs/multimae_sen1floods11.yaml +++ b/examples/confs/multimae_sen1floods11.yaml @@ -39,9 +39,9 @@ data: - LULC rgb_modality: S2L2A # If not provided, uses first modality rgb_indices: + - 3 - 2 - 1 - - 0 train_data_root: S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand @@ -89,8 +89,7 @@ data: S1: - -12.577 - -20.265 - LULC: - - 0 + stds: S2L2A: - 1160.144 @@ -108,8 +107,6 @@ data: S1: - 5.179 - 5.872 - LULC: - - 1 num_classes: 2 @@ -118,9 +115,7 @@ data: init_args: height: 224 width: 224 - - class_path: 
albumentations.HorizontalFlip - init_args: - p: 0.5 + - class_path: albumentations.D4 - class_path: ToTensorV2 @@ -129,17 +124,18 @@ model: init_args: model_factory: EncoderDecoderFactory model_args: - decoder: FCNDecoder backbone_pretrained: false backbone: multimae_base backbone_input_adapters: - S1 - S2L2A - LULC + decoder: UperNetDecoder # FCNDecoder + # decoder_num_convs: 4 # only for FCNDecoder + decoder_scale_modules: True # only for UperNetDecoder decoder_channels: 256 num_classes: 2 head_dropout: 0.1 - decoder_num_convs: 4 head_channel_list: - 256 loss: ce diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml new file mode 100644 index 00000000..93f10167 --- /dev/null +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -0,0 +1,170 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: output + name: multimodal_prithvi_sen1floods11 + version: test_best + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: True + default_root_dir: output/multimodal_prithvi_sen1floods11/ + +data: + class_path: GenericMultiModalDataModule + init_args: + task: 'segmentation' + batch_size: 2 + num_workers: 0 + modalities: + - S2L2A + - S1 + rgb_modality: S2L2A # If not provided, uses first modality + rgb_indices: + - 3 + - 2 + - 1 + + train_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + val_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + test_data_root: + S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand + S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand + test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand + + train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt + val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt + test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt + + allow_substring_split_file: True + img_grep: + S2L2A: "*_S2L2AHand.tif" + S1: "*_S1Hand.tif" + label_grep: "*_LabelHand.tif" + no_label_replace: -1 + no_data_replace: 0 + concat_bands: true # Concatenate S2 and S2 bands + + means: + S2L2A: + - 1793.243 + - 1924.863 + - 2184.553 + - 2340.936 + - 2671.402 + - 3240.082 + - 3468.412 + - 3563.244 + - 3627.704 + - 3711.071 + - 3416.714 + - 2849.625 + S1: + - -12.577 + - -20.265 + + stds: + S2L2A: + - 1160.144 + - 1201.092 + - 1219.943 + - 1397.225 + - 1400.035 + - 1373.136 + - 1429.17 + - 1485.025 + - 1447.836 + - 1652.703 + - 1471.002 + - 1365.30 + S1: + - 5.179 + - 5.872 + + num_classes: 2 + + train_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: albumentations.D4 + - class_path: ToTensorV2 + + +model: + class_path: 
terratorch.tasks.SemanticSegmentationTask + init_args: + model_factory: EncoderDecoderFactory + model_args: + backbone: prithvi_vit_100 + backbone_pretrained: false + backbone_bands: + - COASTAL_AEROSOL + - BLUE + - GREEN + - RED + - RED_EDGE_1 + - RED_EDGE_2 + - RED_EDGE_3 + - NIR_BROAD + - NIR_NARROW + - CIRRUS + - SWIR_1 + - SWIR_2 + - VV + - VH + decoder: FCNDecoder # FCNDecoder + decoder_num_convs: 4 # only for FCNDecoder + # decoder_scale_modules: True # only for UperNetDecoder + decoder_channels: 256 + num_classes: 2 + head_dropout: 0.1 + head_channel_list: + - 256 + + loss: dice + ignore_index: -1 + class_weights: + - 0.3 + - 0.7 + class_names: + - Others + - Flood + freeze_backbone: false + freeze_decoder: false + +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 6.e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/examples/confs/multimodal_sen1floods11.yaml b/examples/confs/multimodal_sen1floods11.yaml deleted file mode 100644 index f591b229..00000000 --- a/examples/confs/multimodal_sen1floods11.yaml +++ /dev/null @@ -1,183 +0,0 @@ -# lightning.pytorch==2.1.1 -seed_everything: 0 -trainer: - accelerator: auto - strategy: auto - devices: auto - num_nodes: 1 - precision: 16-mixed - logger: - class_path: TensorBoardLogger - init_args: - save_dir: output - name: sen1floods11_MM - callbacks: - - class_path: RichProgressBar - - class_path: LearningRateMonitor - init_args: - logging_interval: epoch - - class_path: EarlyStopping - init_args: - monitor: val/loss - patience: 40 - - max_epochs: 2 - check_val_every_n_epoch: 1 - log_every_n_steps: 50 - enable_checkpointing: true - default_root_dir: output/sen1floods11_MM/ - -data: - class_path: GenericMultiModalDataModule - init_args: - task: 'segmentation' - batch_size: 4 - num_workers: 0 - modalities: - - S2L2A - - S1 - - LULC - S2L2A_dataset_bands: - - COASTAL_AEROSOL - - BLUE - - GREEN - - RED - - RED_EDGE_1 - - RED_EDGE_2 - - RED_EDGE_3 - - NIR_BROAD - - NIR_NARROW - - CIRRUS - - SWIR_1 - - SWIR_2 - S2L2A_output_bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - S1_dataset_bands: - - vv - - vh - S1_output_bands: - - vv - - vh - LULC_dataset_bands: - - lulc - LULC_output_bands: - - lulc - rgb_modality: S2L2A # If not provided, uses first modality - rgb_indices: - - 0 - - 1 - - 2 - - train_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand - train_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand - train_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand - train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand - val_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand - val_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand - val_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand - val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand - test_S2L2A_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand - test_S1_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand - test_LULC_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand - test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand - - train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt - val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt - 
test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt - - S2L2A_grep: "*_S2L2AHand.tif" - S1_grep: "*_S1Hand.tif" - LULC_grep: "*_LULCHand.npy" - label_grep: "*_LabelHand.tif" - no_data_replace: 0 - no_label_replace: -1 - - S2L2A_constant_scale: 1. - S1_constant_scale: 1. - LULC_constant_scale: 1. - - S2L2A_means: - - 0.1412956 - - 0.13795798 - - 0.12353792 - - 0.30902815 - - 0.2044958 - - 0.11912015 - S2L2A_stds: - - 0.07406382 - - 0.07370365 - - 0.08692279 - - 0.11798815 - - 0.09772074 - - 0.07659938 - S1_means: - - -20 - - -20 - S1_stds: - - 10 - - 10 - LULC_means: - - 0 - LULC_stds: - - 1 - - num_classes: 2 - -# train_transform: -# - class_path: albumentations.CenterCrop # TODO: How to handle transforms with multiple modalities? -# init_args: -# height: 224 -# width: 224 -# - class_path: albumentations.HorizontalFlip -# init_args: -# p: 0.5 -# - class_path: ToTensorV2 - - -model: - class_path: terratorch.tasks.SemanticSegmentationTask - init_args: - model_factory: PrithviModelFactory - model_args: - decoder: FCNDecoder - pretrained: true - backbone: prithvi_vit_100 - decoder_channels: 256 - in_channels: 6 - bands: - - BLUE - - GREEN - - RED - - NIR_NARROW - - SWIR_1 - - SWIR_2 - num_frames: 1 - num_classes: 2 - head_dropout: 0.1 - decoder_num_convs: 4 - head_channel_list: - - 256 - loss: ce - ignore_index: -1 - class_weights: - - 0.3 - - 0.7 - freeze_backbone: false - freeze_decoder: false - - -optimizer: - class_path: torch.optim.AdamW - init_args: - lr: 6.e-5 - weight_decay: 0.05 -lr_scheduler: - class_path: ReduceLROnPlateau - init_args: - monitor: val/loss - diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 4f8584a5..303793b4 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -14,12 +14,13 @@ import numpy as np import torch from torch import Tensor -from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler +from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler, default_collate from torchgeo.datamodules import NonGeoDataModule from torchgeo.transforms import AugmentationSequential from terratorch.datasets import (GenericMultimodalDataset, GenericMultimodalSegmentationDataset, - GenericMultimodalPixelwiseRegressionDataset, HLSBands) + GenericMultimodalPixelwiseRegressionDataset, GenericMultimodalScalarDataset, HLSBands) +from terratorch.datamodules.generic_pixel_wise_data_module import Normalize from terratorch.io.file import load_from_file_or_attribute @@ -28,7 +29,7 @@ def collate_chunk_dicts(batch_list): for key, value in batch_list[0].items(): # TODO: Handle missing modalities when allow_missing_modalities is set. if isinstance(value, torch.Tensor): batch[key] = torch.concat([chunk[key] for chunk in batch_list]) - if isinstance(value, np.ndarray): + elif isinstance(value, np.ndarray): batch[key] = np.concatenate([chunk[key] for chunk in batch_list]) elif isinstance(value, dict): batch[key] = collate_chunk_dicts([chunk[key] for chunk in batch_list]) @@ -37,14 +38,34 @@ def collate_chunk_dicts(batch_list): return batch -def wrap_in_compose_is_list(transform_list, image_modalities=None, sequence_modalities=None): +def collate_samples(batch_list): + """ + Wrapper for default_collate as it cannot handle some datatypes such as np.datetime64. 
+ """ + batch = {} + for key, value in batch_list[0].items(): # TODO: Handle missing modalities when allow_missing_modalities is set. + if isinstance(value, dict): + batch[key] = collate_samples([chunk[key] for chunk in batch_list]) + else: + try: + batch[key] = default_collate([chunk[key] for chunk in batch_list]) + except TypeError: + # Fallback to numpy or simple list + if isinstance(value, np.ndarray): + batch[key] = np.stack([chunk[key] for chunk in batch_list]) + else: + batch[key] = [chunk[key] for chunk in batch_list] + return batch + + +def wrap_in_compose_is_list(transform_list, image_modalities=None, non_image_modalities=None): additional_targets = {} if image_modalities: for modality in image_modalities: additional_targets[modality] = 'image' - if sequence_modalities: + if non_image_modalities: # Global label values are ignored and need to be processed separately - for modality in sequence_modalities: + for modality in non_image_modalities: additional_targets[modality] = "global_label" # set check shapes to false because of the multitemporal case return A.Compose(transform_list, is_check_shapes=False, additional_targets=additional_targets) \ @@ -175,6 +196,7 @@ def __init__( sample_num_modalities: int | None = None, sample_replace: bool = False, channel_position: int = -3, + concat_bands: bool = False, **kwargs: Any, ) -> None: """Constructor @@ -238,6 +260,9 @@ def __init__( dataset_class = GenericMultimodalSegmentationDataset elif task == 'regression': dataset_class = GenericMultimodalPixelwiseRegressionDataset + elif task in ['classification', 'multilabel_classification', 'scalar_regression', 'scalar']: + dataset_class = GenericMultimodalScalarDataset + task = 'scalar' elif task is None: dataset_class = GenericMultimodalDataset else: @@ -247,7 +272,9 @@ def __init__( self.num_classes = num_classes self.modalities = modalities self.image_modalities = image_modalities or modalities - self.sequence_modalities = list(set(self.modalities) - set(image_modalities)) + self.non_image_modalities = list(set(self.modalities) - set(self.image_modalities)) + if task == 'scalar': + self.non_image_modalities += ['label'] if isinstance(img_grep, dict): self.img_grep = {m: img_grep[m] if m in img_grep else '*' for m in modalities} else: @@ -270,7 +297,7 @@ def __init__( self.drop_last = drop_last self.pin_memory = pin_memory self.sample_num_modalities = sample_num_modalities - self.sample_replace = sample_replace + self.sample_replace = sample_replace self.dataset_bands = dataset_bands self.output_bands = output_bands @@ -282,6 +309,7 @@ def __init__( self.expand_temporal_dimension = expand_temporal_dimension self.reduce_zero_label = reduce_zero_label self.channel_position = channel_position + self.concat_bands = concat_bands if isinstance(train_transform, dict): self.train_transform = {m: wrap_in_compose_is_list(train_transform[m]) if m in train_transform else None @@ -289,7 +317,7 @@ def __init__( elif shared_transforms: self.train_transform = wrap_in_compose_is_list(train_transform, image_modalities=self.image_modalities, - sequence_modalities=self.sequence_modalities) + non_image_modalities=self.non_image_modalities) else: self.train_transform = {m: wrap_in_compose_is_list(train_transform) for m in modalities} @@ -300,7 +328,7 @@ def __init__( elif shared_transforms: self.val_transform = wrap_in_compose_is_list(val_transform, image_modalities=self.image_modalities, - sequence_modalities=self.sequence_modalities) + non_image_modalities=self.non_image_modalities) else: self.val_transform = 
{m: wrap_in_compose_is_list(val_transform) for m in modalities} @@ -311,19 +339,29 @@ def __init__( elif shared_transforms: self.test_transform = wrap_in_compose_is_list(test_transform, image_modalities=self.image_modalities, - sequence_modalities=self.sequence_modalities) + non_image_modalities=self.non_image_modalities, + ) else: self.test_transform = {m: wrap_in_compose_is_list(test_transform) for m in modalities} - means = {m: load_from_file_or_attribute(means[m]) for m in means.keys()} - stds = {m: load_from_file_or_attribute(stds[m]) for m in stds.keys()} + if self.concat_bands: + # Concatenate mean and std values + means = load_from_file_or_attribute(np.concatenate([means[m] for m in self.image_modalities]).tolist()) + stds = load_from_file_or_attribute(np.concatenate([stds[m] for m in self.image_modalities]).tolist()) + + self.aug = Normalize(means, stds) + else: + # Apply standardization per modality + means = {m: load_from_file_or_attribute(means[m]) for m in means.keys()} + stds = {m: load_from_file_or_attribute(stds[m]) for m in stds.keys()} - self.aug = MultimodalNormalize(means, stds) + self.aug = MultimodalNormalize(means, stds) self.chunk_data = chunk_data - if chunk_data: - self.collate_fn = collate_chunk_dicts + + self.collate_fn = collate_chunk_dicts if chunk_data else collate_samples + def setup(self, stage: str) -> None: if stage in ["fit"]: @@ -347,6 +385,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, channel_position=self.channel_position, + concat_bands=self.concat_bands , ) logging.info(f'Train dataset: {len(self.train_dataset)}') if stage in ["fit", "validate"]: @@ -370,6 +409,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, channel_position=self.channel_position, + concat_bands=self.concat_bands, ) logging.info(f'Val dataset: {len(self.val_dataset)}') if stage in ["test"]: @@ -393,6 +433,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, channel_position=self.channel_position, + concat_bands=self.concat_bands, ) logging.info(f'Test dataset: {len(self.test_dataset)}') if stage in ["predict"] and self.predict_root: @@ -412,6 +453,7 @@ def setup(self, stage: str) -> None: expand_temporal_dimension=self.expand_temporal_dimension, reduce_zero_label=self.reduce_zero_label, channel_position=self.channel_position, + concat_bands=self.concat_bands, ) logging.info(f'Predict dataset: {len(self.predict_dataset)}') diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index f30f6c20..7bf449a9 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -11,7 +11,8 @@ from terratorch.datasets.generic_multimodal_dataset import ( GenericMultimodalDataset, GenericMultimodalSegmentationDataset, - GenericMultimodalPixelwiseRegressionDataset + GenericMultimodalPixelwiseRegressionDataset, + GenericMultimodalScalarDataset, ) from terratorch.datasets.hls import HLSL30, HLSS30 from terratorch.datasets.m_bigearthnet import MBigEarthNonGeo @@ -52,6 +53,7 @@ "GenericMultimodalDataset", "GenericMultimodalSegmentationDataset", "GenericMultimodalPixelwiseRegressionDataset", + "GenericMultimodalScalarDataset", "FireScarsNonGeo", "FireScarsHLS", "FireScarsSegmentationMask", diff --git a/terratorch/datasets/generic_multimodal_dataset.py 
b/terratorch/datasets/generic_multimodal_dataset.py index 634ca337..72658d6a 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -3,6 +3,7 @@ """Module containing generic dataset classes""" import glob +import logging import os import torch import pandas as pd @@ -25,7 +26,8 @@ from terratorch.datasets.transforms import MultimodalTransforms -def load_table_data(file_path: str): +def load_table_data(file_path: str | Path) -> pd.DataFrame: + file_path = str(file_path) if file_path.endswith('parquet'): df = pd.read_parquet(file_path) elif file_path.endswith('csv'): @@ -72,7 +74,7 @@ def __init__( image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, - allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet (collate_fn required). + allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet. allow_substring_split_file: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, @@ -83,6 +85,8 @@ def __init__( expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -1, + scalar_label: bool = False, + concat_bands: bool = False, *args, **kwargs, ) -> None: """Constructor @@ -120,13 +124,15 @@ def __init__( self.modalities = list(data_root.keys()) assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." self.image_modalities = image_modalities or self.modalities - self.sequence_modalities = list(set(self.modalities) - set(image_modalities)) + self.non_image_modalities = list(set(self.modalities) - set(image_modalities)) + if scalar_label: + self.non_image_modalities += ['label'] # Convert path strings to lists as the code expects a list of paths per modality for m, m_path in data_root.items(): if not isinstance(m_path, list): data_root[m] = [m_path] - if label_data_root and not isinstance(label_data_root, list): + if label_data_root is not None and not isinstance(label_data_root, list): label_data_root = [label_data_root] self.constant_scale = constant_scale or [] @@ -135,6 +141,13 @@ def __init__( self.reduce_zero_label = reduce_zero_label self.expand_temporal_dimension = expand_temporal_dimension self.channel_position = channel_position + self.scalar_label = scalar_label + self.concat_bands = concat_bands + assert not self.concat_bands or len(self.non_image_modalities) == 0, \ + (f"concat_bands can only be used with image modalities, " + f"but non-image modalities are given: {self.non_image_modalities}") + assert not self.concat_bands or not allow_missing_modalities, \ + "concat_bands cannot be used with allow_missing_modalities." 
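The concat_bands option introduced here lets the dataset return one stacked tensor instead of a per-modality dict, so that a single-input backbone (as in the new multimodal_prithvi_sen1floods11.yaml config) can consume the fused bands; the assertions above restrict it to image modalities with no missing modalities. A minimal sketch of the effect, with made-up channel counts:

    import torch

    sample = {
        "S2L2A": torch.rand(12, 224, 224),  # hypothetical 12 optical bands
        "S1": torch.rand(2, 224, 224),      # hypothetical 2 SAR bands
    }
    image_modalities = ["S2L2A", "S1"]

    # concat_bands=True concatenates along the channel dimension
    image = torch.cat([sample[m] for m in image_modalities])
    print(image.shape)  # torch.Size([14, 224, 224])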
if self.expand_temporal_dimension and len(dataset_bands) != self.modalities: msg = "Please provide dataset_bands for each modality when expand_temporal_dimension is True" @@ -147,7 +160,7 @@ def __init__( split = f.readlines() valid_files = {rf"{substring.strip()}" for substring in split} else: - valid_files = list(load_table_data(str(self.split_file)).index) + valid_files = list(load_table_data(self.split_file).index) else: image_files = {} @@ -155,7 +168,7 @@ def __init__( dir_lists = [glob.glob(os.path.join(r, image_grep[m])) for r in m_paths] image_files[m] = sorted([p for l in dir_lists for p in l]) # Concatenate - if label_data_root: + if label_data_root is not None: dir_lists = [glob.glob(os.path.join(r, label_grep)) for r in label_data_root] image_files['mask'] = sorted([p for l in dir_lists for p in l]) # Concatenate @@ -189,7 +202,7 @@ def __init__( (f"Sample key expected in table index (first column) for {m}, " f"key '{list(valid_files)[0]}' is not in index [{list(data_root[m].index[:3])}, ...]") - if label_data_root: + if label_data_root is not None: # Check for parquet and csv files with labels and read the file l_dfs = [] for l_path in label_data_root: @@ -229,29 +242,28 @@ def __init__( if os.path.isfile(file_path): sample[m] = file_path break - if label_data_root: + if label_data_root is not None: # Add tabular data to sample if isinstance(label_data_root, pd.DataFrame): # label_data_root was replaced by DataFrame sample['mask'] = label_data_root.loc[file].values - continue - - for l_dir in label_data_root: - if allow_substring_split_file: - # Substring match with label_grep - l_files = glob.glob(os.path.join(l_dir, file + label_grep)) - if l_files: - sample['mask'] = l_files[0] - break - else: - # Exact match - file_path = os.path.join(l_dir, file) - if os.path.isfile(file_path): - sample['mask'] = file_path - break - if 'mask' not in sample: - # Only add sample if mask is present - break + else: + for l_dir in label_data_root: + if allow_substring_split_file: + # Substring match with label_grep + l_files = glob.glob(os.path.join(l_dir, file + label_grep)) + if l_files: + sample['mask'] = l_files[0] + break + else: + # Exact match + file_path = os.path.join(l_dir, file) + if os.path.isfile(file_path): + sample['mask'] = file_path + break + if 'mask' not in sample: + # Only add sample if mask is present + break if len(sample) == num_modalities or allow_missing_modalities: self.samples.append(sample) @@ -286,9 +298,13 @@ def __init__( self.filter_indices[m] = [self.dataset_bands[m].index(band) for band in self.output_bands[m]] + if not self.channel_position: + logging.warning('output_bands is defined but no channel_position is provided. 
' + 'Channels must be in the last dimension, otherwise provide channel_position.') + # If no transform is given, apply only to transform to torch tensor if isinstance(transform, A.Compose): - self.transform = MultimodalTransforms(transform, sequence_modalities=self.sequence_modalities) + self.transform = MultimodalTransforms(transform, non_image_modalities=self.non_image_modalities) elif transform is None: self.transform = MultimodalToTensor(self.modalities) else: @@ -327,7 +343,8 @@ def __getitem__(self, index: int) -> dict[str, Any]: data = rearrange(data, "(channels time) h w -> channels time h w", channels=len(self.dataset_bands[modality])) - if modality == 'mask': + if modality == 'mask' and len(data) == 1: + # tasks expect image masks without channel dim data = data[0] if modality in self.image_modalities and len(data.shape) >= 3 and self.channel_position: @@ -344,15 +361,21 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.reduce_zero_label: output["mask"] -= 1 + + if self.scalar_label: + output["label"] = output.pop("mask") + if self.transform: output = self.transform(output) - # Tasks expect data to be stored in 'image', moving modalities to image dict - output = { - 'image': {m: output[m] for m in self.modalities if m in output}, - 'mask': output['mask'] if 'mask' in output else None, - 'filename': self.samples[index] - } + if self.concat_bands: + # Concatenate bands of all image modalities + output['image'] = torch.cat([output.pop(m) for m in self.image_modalities if m in output]) + else: + # Tasks expect data to be stored in 'image', moving modalities to image dict + output['image'] = {m: output.pop(m) for m in self.modalities if m in output} + + output['filename'] = self.samples[index] return output @@ -400,6 +423,7 @@ def __init__( expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -3, + *args, **kwargs, ) -> None: """Constructor @@ -432,6 +456,8 @@ def __init__( reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. """ + assert label_data_root is not None, "label_data_root must be specified for segmentation tasks." + super().__init__( data_root, label_data_root=label_data_root, @@ -452,6 +478,7 @@ def __init__( expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, channel_position=channel_position, + *args, **kwargs, ) self.num_classes = num_classes self.class_names = class_names @@ -564,6 +591,7 @@ def __init__( expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -3, + *args, **kwargs, ) -> None: """Constructor @@ -594,6 +622,8 @@ def __init__( reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. """ + assert label_data_root is not None, "label_data_root must be specified for regression tasks." 
+ super().__init__( data_root, label_data_root=label_data_root, @@ -614,6 +644,7 @@ def __init__( expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, channel_position=channel_position, + *args, **kwargs, ) def __getitem__(self, index: int) -> dict[str, Any]: @@ -686,3 +717,158 @@ def _plot_sample(image, label, prediction=None, suptitle=None): if suptitle is not None: plt.suptitle(suptitle) return fig + + +class GenericMultimodalScalarDataset(GenericMultimodalDataset): + """GenericMultimodalClassificationDataset""" + + def __init__( + self, + data_root: Path, + label_data_root: Path, + image_grep: str | None = "*", + label_grep: str | None = "*", + split: Path | None = None, + image_modalities: list[str] | None = None, + rgb_modality: str | None = None, + rgb_indices: list[int] | None = None, + allow_missing_modalities : bool = False, + allow_substring_split_file: bool = False, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + constant_scale: float = 1., # TODO: Check types of args + transform: A.Compose | None = None, + no_data_replace: float | None = None, + no_label_replace: int | None = None, + expand_temporal_dimension: bool = False, + reduce_zero_label: bool = False, + channel_position: int = -3, + *args, **kwargs, + ) -> None: + """Constructor + + Args: + TODO: Update docs + data_root (Path): Path to data root directory + label_data_root (Path, optional): Path to data root directory with labels. + If not specified, will use the same as for images. + image_grep (str, optional): Regular expression appended to data_root to find input images. + Defaults to "*". + label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. + Defaults to "*". + split (Path, optional): Path to file containing files to be used for this split. + The file should be a new-line separated prefixes contained in the desired files. + Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. + dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. + output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + constant_scale (float): Factor to multiply image values by. Defaults to 1. + transform (Albumentations.Compose | None): Albumentations transform to be applied. + Should end with ToTensorV2(). If used through the generic_data_module, + should not include normalization. Not supported for multi-temporal data. + Defaults to None, which simply applies ToTensorV2(). + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Defaults to False. + reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the + expected 0. Defaults to False. + """ + assert label_data_root is not None, "label_data_root must be specified for scalar tasks." 
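For scalar tasks the labels do not have to be raster masks: as with the tabular modalities added in the earlier patch, label_data_root can point to a csv or parquet file whose index holds the sample keys and which is read through load_table_data. A hedged sketch of such a labels table, with hypothetical sample ids and a single class column:

    import pandas as pd

    # Hypothetical labels table: index = sample keys from the split file,
    # column(s) = scalar targets. Saved as csv so load_table_data can read it.
    labels = pd.DataFrame({"label": [0, 1, 1]},
                          index=["sample_0001", "sample_0002", "sample_0003"])
    labels.to_csv("labels.csv")

    # The dataset then looks rows up per sample key, e.g. df.loc[key].values
    df = pd.read_csv("labels.csv", index_col=0)
    print(df.loc["sample_0002"].values)  # [1]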
+ + super().__init__( + data_root, + label_data_root=label_data_root, + image_grep=image_grep, + label_grep=label_grep, + split=split, + image_modalities=image_modalities, + rgb_modality=rgb_modality, + rgb_indices=rgb_indices, + allow_missing_modalities=allow_missing_modalities, + allow_substring_split_file=allow_substring_split_file, + dataset_bands=dataset_bands, + output_bands=output_bands, + constant_scale=constant_scale, + transform=transform, + no_data_replace=no_data_replace, + no_label_replace=no_label_replace, + expand_temporal_dimension=expand_temporal_dimension, + reduce_zero_label=reduce_zero_label, + channel_position=channel_position, + scalar_label=True, + *args, **kwargs, + ) + + def __getitem__(self, index: int) -> dict[str, Any]: + item = super().__getitem__(index) + return item + + def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample (dict[str, Tensor]): a sample returned by :meth:`__getitem__` + suptitle (str|None): optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.2 + """ + + # TODO: Check plotting code for classification tasks and add it to generic classification dataset as well + raise NotImplementedError + + image = sample["image"] + if isinstance(image, torch.Tensor): + image = image.numpy() + image = image.take(self.rgb_indices, axis=0) + image = np.transpose(image, (1, 2, 0)) + image = (image - image.min(axis=(0, 1))) * (1 / image.max(axis=(0, 1))) + image = np.clip(image, 0, 1) + + label_mask = sample["mask"] + if isinstance(label_mask, torch.Tensor): + label_mask = label_mask.numpy() + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_mask = sample["prediction"] + if isinstance(prediction_mask, torch.Tensor): + prediction_mask = prediction_mask.numpy() + + return self._plot_sample( + image, + label_mask, + prediction=prediction_mask if showing_predictions else None, + suptitle=suptitle, + ) + + @staticmethod + def _plot_sample(image, label, prediction=None, suptitle=None): + num_images = 4 if prediction is not None else 3 + fig, ax = plt.subplots(1, num_images, figsize=(12, 10), layout="compressed") + + norm = mpl.colors.Normalize(vmin=label.min(), vmax=label.max()) + ax[0].axis("off") + ax[0].title.set_text("Image") + ax[0].imshow(image) + + ax[1].axis("off") + ax[1].title.set_text("Ground Truth Mask") + ax[1].imshow(label, cmap="Greens", norm=norm) + + ax[2].axis("off") + ax[2].title.set_text("GT Mask on Image") + ax[2].imshow(image) + ax[2].imshow(label, cmap="Greens", alpha=0.3, norm=norm) + # ax[2].legend() + + if prediction is not None: + ax[3].title.set_text("Predicted Mask") + ax[3].imshow(prediction, cmap="Greens", norm=norm) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig From 0d4c7ea1ffe6ae635b1e1cc6d60a597eaf500c5c Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 17:52:46 +0100 Subject: [PATCH 25/42] Renaming Signed-off-by: Benedikt Blumenstiel --- .../generic_multimodal_data_module.py | 40 +++++++++--------- .../generic_pixel_wise_data_module.py | 4 +- .../datasets/generic_multimodal_dataset.py | 42 +++++++++---------- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 303793b4..d7cfeffe 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ 
b/terratorch/datamodules/generic_multimodal_data_module.py @@ -62,7 +62,7 @@ def wrap_in_compose_is_list(transform_list, image_modalities=None, non_image_mod additional_targets = {} if image_modalities: for modality in image_modalities: - additional_targets[modality] = 'image' + additional_targets[modality] = "image" if non_image_modalities: # Global label values are ignored and need to be processed separately for modality in non_image_modalities: @@ -80,9 +80,9 @@ def __init__(self, means, stds): def __call__(self, batch): for m in self.means.keys(): - if m not in batch['image']: + if m not in batch["image"]: continue - image = batch['image'][m] + image = batch["image"][m] if len(image.shape) == 5: # B, C, T, H, W means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1, 1) @@ -103,7 +103,7 @@ def __call__(self, batch): msg = (f"Expected batch with 5 or 4 dimensions (B, C, (T,) H, W), sample with 3 dimensions (C, H, W) " f"or a single channel, but got {len(image.shape)}") raise Exception(msg) - batch['image'][m] = (image - means) / stds + batch["image"][m] = (image - means) / stds return batch @@ -119,7 +119,7 @@ def __init__(self, modalities, sample_num_modalities, sample_replace, *args, **k def __iter__(self) -> Iterator[list[int]]: """ - Code similar to BatchSampler but samples tuples in the format (idx, ['m1', 'm2', ...]) + Code similar to BatchSampler but samples tuples in the format (idx, ["m1", "m2", ...]) """ # Select sampled modalities per batch sampled_modalities = np.random.choice(self.modalities, self.sample_num_modalities, replace=self.sample_replace) @@ -256,30 +256,30 @@ def __init__( into device/CUDA pinned memory before returning them. Defaults to False. """ - if task == 'segmentation': + if task == "segmentation": dataset_class = GenericMultimodalSegmentationDataset - elif task == 'regression': + elif task == "regression": dataset_class = GenericMultimodalPixelwiseRegressionDataset - elif task in ['classification', 'multilabel_classification', 'scalar_regression', 'scalar']: + elif task in ["classification", "multilabel_classification", "scalar_regression", "scalar"]: dataset_class = GenericMultimodalScalarDataset - task = 'scalar' + task = "scalar" elif task is None: dataset_class = GenericMultimodalDataset else: - raise ValueError(f'Unknown task {task}, only segmentation and regression are supported.') + raise ValueError(f"Unknown task {task}, only segmentation and regression are supported.") super().__init__(dataset_class, batch_size, num_workers, **kwargs) self.num_classes = num_classes self.modalities = modalities self.image_modalities = image_modalities or modalities self.non_image_modalities = list(set(self.modalities) - set(self.image_modalities)) - if task == 'scalar': - self.non_image_modalities += ['label'] + if task == "scalar": + self.non_image_modalities += ["label"] if isinstance(img_grep, dict): - self.img_grep = {m: img_grep[m] if m in img_grep else '*' for m in modalities} + self.img_grep = {m: img_grep[m] if m in img_grep else "*" for m in modalities} else: - self.img_grep = {m: img_grep or '*' for m in modalities} - self.label_grep = label_grep or '*' + self.img_grep = {m: img_grep or "*" for m in modalities} + self.label_grep = label_grep or "*" self.train_root = train_data_root self.val_root = val_data_root self.test_root = test_data_root @@ -387,7 +387,7 @@ def setup(self, stage: str) -> None: channel_position=self.channel_position, concat_bands=self.concat_bands , ) - logging.info(f'Train dataset: {len(self.train_dataset)}') + 
logging.info(f"Train dataset: {len(self.train_dataset)}") if stage in ["fit", "validate"]: self.val_dataset = self.dataset_class( data_root=self.val_root, @@ -411,7 +411,7 @@ def setup(self, stage: str) -> None: channel_position=self.channel_position, concat_bands=self.concat_bands, ) - logging.info(f'Val dataset: {len(self.val_dataset)}') + logging.info(f"Val dataset: {len(self.val_dataset)}") if stage in ["test"]: self.test_dataset = self.dataset_class( data_root=self.test_root, @@ -435,7 +435,7 @@ def setup(self, stage: str) -> None: channel_position=self.channel_position, concat_bands=self.concat_bands, ) - logging.info(f'Test dataset: {len(self.test_dataset)}') + logging.info(f"Test dataset: {len(self.test_dataset)}") if stage in ["predict"] and self.predict_root: self.predict_dataset = self.dataset_class( data_root=self.predict_root, @@ -455,13 +455,13 @@ def setup(self, stage: str) -> None: channel_position=self.channel_position, concat_bands=self.concat_bands, ) - logging.info(f'Predict dataset: {len(self.predict_dataset)}') + logging.info(f"Predict dataset: {len(self.predict_dataset)}") def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: """Implement one or more PyTorch DataLoaders. Args: - split: Either 'train', 'val', 'test', or 'predict'. + split: Either "train", "val", "test", or "predict". Returns: A collection of data loaders specifying samples. diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index c21a7651..47c08aa0 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -108,7 +108,7 @@ def __init__( no_data_replace: float | None = None, no_label_replace: int | None = None, drop_last: bool = True, - pin_memory: bool = True, + pin_memory: bool = False, **kwargs: Any, ) -> None: """Constructor @@ -365,7 +365,7 @@ def __init__( no_data_replace: float | None = None, no_label_replace: int | None = None, drop_last: bool = True, - pin_memory: bool = True, + pin_memory: bool = False, **kwargs: Any, ) -> None: """Constructor diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index 72658d6a..414cef3f 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -28,9 +28,9 @@ def load_table_data(file_path: str | Path) -> pd.DataFrame: file_path = str(file_path) - if file_path.endswith('parquet'): + if file_path.endswith("parquet"): df = pd.read_parquet(file_path) - elif file_path.endswith('csv'): + elif file_path.endswith("csv"): df = pd.read_csv(file_path, index_col=0) else: raise Exception(f"Unrecognized file type: {file_path}. Only parquet and csv are supported.") @@ -53,7 +53,7 @@ def __call__(self, d): elif len(v.shape) == 5: v = np.moveaxis(v, -1, 1) # B, C, T, H, W else: - raise ValueError(f'Unexpected shape for {k}: {v.shape}') + raise ValueError(f"Unexpected shape for {k}: {v.shape}") new_dict[k] = torch.from_numpy(v) return new_dict @@ -122,11 +122,11 @@ def __init__( self.split_file = split self.modalities = list(data_root.keys()) - assert 'mask' not in self.modalities, "Modality cannot be called 'mask'." + assert "mask" not in self.modalities, "Modality cannot be called 'mask'." 
self.image_modalities = image_modalities or self.modalities self.non_image_modalities = list(set(self.modalities) - set(image_modalities)) if scalar_label: - self.non_image_modalities += ['label'] + self.non_image_modalities += ["label"] # Convert path strings to lists as the code expects a list of paths per modality for m, m_path in data_root.items(): @@ -155,7 +155,7 @@ def __init__( # Load samples based on split file if self.split_file is not None: - if str(self.split_file).endswith('.txt'): + if str(self.split_file).endswith(".txt"): with open(self.split_file) as f: split = f.readlines() valid_files = {rf"{substring.strip()}" for substring in split} @@ -170,7 +170,7 @@ def __init__( if label_data_root is not None: dir_lists = [glob.glob(os.path.join(r, label_grep)) for r in label_data_root] - image_files['mask'] = sorted([p for l in dir_lists for p in l]) # Concatenate + image_files["mask"] = sorted([p for l in dir_lists for p in l]) # Concatenate if allow_substring_split_file: # Get exact match of filenames @@ -246,22 +246,22 @@ def __init__( # Add tabular data to sample if isinstance(label_data_root, pd.DataFrame): # label_data_root was replaced by DataFrame - sample['mask'] = label_data_root.loc[file].values + sample["mask"] = label_data_root.loc[file].values else: for l_dir in label_data_root: if allow_substring_split_file: # Substring match with label_grep l_files = glob.glob(os.path.join(l_dir, file + label_grep)) if l_files: - sample['mask'] = l_files[0] + sample["mask"] = l_files[0] break else: # Exact match file_path = os.path.join(l_dir, file) if os.path.isfile(file_path): - sample['mask'] = file_path + sample["mask"] = file_path break - if 'mask' not in sample: + if "mask" not in sample: # Only add sample if mask is present break @@ -299,8 +299,8 @@ def __init__( self.filter_indices[m] = [self.dataset_bands[m].index(band) for band in self.output_bands[m]] if not self.channel_position: - logging.warning('output_bands is defined but no channel_position is provided. ' - 'Channels must be in the last dimension, otherwise provide channel_position.') + logging.warning("output_bands is defined but no channel_position is provided. 
" + "Channels must be in the last dimension, otherwise provide channel_position.") # If no transform is given, apply only to transform to torch tensor if isinstance(transform, A.Compose): @@ -334,7 +334,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: for modality, file in sample.items(): data = self._load_file( file, - nan_replace=self.no_label_replace if modality == 'mask' else self.no_data_replace, + nan_replace=self.no_label_replace if modality == "mask" else self.no_data_replace, modality=modality, ) @@ -343,7 +343,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: data = rearrange(data, "(channels time) h w -> channels time h w", channels=len(self.dataset_bands[modality])) - if modality == 'mask' and len(data) == 1: + if modality == "mask" and len(data) == 1: # tasks expect image masks without channel dim data = data[0] @@ -370,12 +370,12 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.concat_bands: # Concatenate bands of all image modalities - output['image'] = torch.cat([output.pop(m) for m in self.image_modalities if m in output]) + output["image"] = torch.cat([output.pop(m) for m in self.image_modalities if m in output]) else: - # Tasks expect data to be stored in 'image', moving modalities to image dict - output['image'] = {m: output.pop(m) for m in self.modalities if m in output} + # Tasks expect data to be stored in "image", moving modalities to image dict + output["image"] = {m: output.pop(m) for m in self.modalities if m in output} - output['filename'] = self.samples[index] + output["filename"] = self.samples[index] return output @@ -383,11 +383,11 @@ def _load_file(self, path, nan_replace: int | float | None = None, modality: str if isinstance(path, np.ndarray): # data was loaded from table and is saved in memory data = path - elif path.endswith('.zarr') or path.endswith('.zarr.zip'): + elif path.endswith(".zarr") or path.endswith(".zarr.zip"): data = xr.open_zarr(path, mask_and_scale=True) data_var = modality if modality in data.data_vars else list(data.data_vars)[0] data = data[data_var].to_numpy() - elif path.endswith('.npy'): + elif path.endswith(".npy"): data = np.load(path) else: data = rioxarray.open_rasterio(path, masked=True).to_numpy() From 87d22f9d6fa8c4bdabd97da1510833990bfcb97d Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 14 Nov 2024 17:54:45 +0100 Subject: [PATCH 26/42] Renaming Signed-off-by: Benedikt Blumenstiel --- terratorch/datasets/transforms.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py index f175e3a6..478c348e 100644 --- a/terratorch/datasets/transforms.py +++ b/terratorch/datasets/transforms.py @@ -158,7 +158,7 @@ def get_transform_init_args_names(self): return "band_indices" -def default_sequence_transform(array): +def default_non_image_transform(array): if array.dtype == float or array.dtype == int: return torch.from_numpy(array) else: @@ -171,25 +171,25 @@ def __init__( self, transforms: dict | A.Compose, shared : bool = True, - sequence_modalities: list[str] | None = None, - sequence_transform: object | None = None, + non_image_modalities: list[str] | None = None, + non_image_transform: object | None = None, ): self.transforms = transforms self.shared = shared - self.sequence_modalities = sequence_modalities - self.sequence_transform = sequence_transform or default_sequence_transform + self.non_image_modalities = non_image_modalities + self.non_image_transform = non_image_transform or 
default_non_image_transform def __call__(self, data: dict): if self.shared: # albumentations requires a key 'image' and treats all other keys as additional targets - image_modality = list(set(data.keys()) - set(self.sequence_modalities))[0] + image_modality = list(set(data.keys()) - set(self.non_image_modalities))[0] data['image'] = data.pop(image_modality) data = self.transforms(**data) data[image_modality] = data.pop('image') # Process sequence data which is ignored by albumentations as 'global_label' - for modality in self.sequence_modalities: - data[modality] = self.sequence_transform(data[modality]) + for modality in self.non_image_modalities: + data[modality] = self.non_image_transform(data[modality]) else: # Applies transformations for each modality separate for key, value in data.items(): From fe1ccb7cf8cae0fa99d12dcfe927b702f79d406f Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Mon, 18 Nov 2024 14:41:05 +0100 Subject: [PATCH 27/42] Updated multimodal docs Signed-off-by: Benedikt Blumenstiel --- examples/confs/multimae_sen1floods11.yaml | 2 +- .../multimodal_prithvi_sen1floods11.yaml | 2 +- .../generic_multimodal_data_module.py | 195 +++++++---- .../datasets/generic_multimodal_dataset.py | 329 ++++++++++++------ 4 files changed, 346 insertions(+), 182 deletions(-) diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml index 73d5bf6b..62a7acb2 100644 --- a/examples/confs/multimae_sen1floods11.yaml +++ b/examples/confs/multimae_sen1floods11.yaml @@ -63,7 +63,7 @@ data: val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt - allow_substring_split_file: True + allow_substring_file_names: True img_grep: S2L2A: "*_S2L2AHand.tif" S1: "*_S1Hand.tif" diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml index 93f10167..d25778a2 100644 --- a/examples/confs/multimodal_prithvi_sen1floods11.yaml +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -60,7 +60,7 @@ data: val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt - allow_substring_split_file: True + allow_substring_file_names: True img_grep: S2L2A: "*_S2L2AHand.tif" S1: "*_S1Hand.tif" diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index d7cfeffe..2101f9bb 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -8,15 +8,12 @@ from collections.abc import Callable, Iterable from pathlib import Path from typing import Any, Iterator - import albumentations as A -import kornia.augmentation as K import numpy as np import torch from torch import Tensor from torch.utils.data import DataLoader, RandomSampler, BatchSampler, SequentialSampler, default_collate from torchgeo.datamodules import NonGeoDataModule -from torchgeo.transforms import AugmentationSequential from terratorch.datasets import (GenericMultimodalDataset, GenericMultimodalSegmentationDataset, GenericMultimodalPixelwiseRegressionDataset, GenericMultimodalScalarDataset, HLSBands) @@ -155,44 +152,45 @@ class GenericMultiModalDataModule(NonGeoDataModule): def __init__( self, batch_size: int, - num_workers: int, modalities: list[str], - train_data_root: dict, - val_data_root: dict, - test_data_root: dict, - means: 
dict, - stds: dict, + train_data_root: dict[Path], + val_data_root: dict[Path], + test_data_root: dict[Path], + means: dict[list], + stds: dict[list], task: str | None = None, num_classes: int | None = None, - img_grep: str | dict | None = None, + image_grep: str | dict | None = None, label_grep: str | None = None, train_label_data_root: Path | None = None, val_label_data_root: Path | None = None, test_label_data_root: Path | None = None, - predict_data_root: Path | None = None, + predict_data_root: dict[Path] | None = None, train_split: Path | None = None, val_split: Path | None = None, test_split: Path | None = None, - dataset_bands: dict | None = None, - output_bands: dict | None = None, - predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + dataset_bands: dict[list] | None = None, + output_bands: dict[list] | None = None, + predict_dataset_bands: dict[list] | None = None, + predict_output_bands: dict[list] | None = None, image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, - allow_substring_split_file: bool = False, + allow_substring_file_names: bool = False, + class_names: list[str] | None = None, constant_scale: dict[float] = None, train_transform: dict | A.Compose | None | list[A.BasicTransform] = None, val_transform: dict | A.Compose | None | list[A.BasicTransform] = None, test_transform: dict | A.Compose | None | list[A.BasicTransform] = None, shared_transforms: list | bool = True, expand_temporal_dimension: bool = False, - reduce_zero_label: bool = False, no_data_replace: float | None = None, - no_label_replace: int | None = None, + no_label_replace: float | None = -1, + reduce_zero_label: bool = False, drop_last: bool = True, + num_workers: int = 0, pin_memory: bool = False, - chunk_data: bool = False, + data_with_sample_dim: bool = False, sample_num_modalities: int | None = None, sample_replace: bool = False, channel_position: int = -3, @@ -202,59 +200,105 @@ def __init__( """Constructor Args: - # TODO: Update docs - batch_size (int): _description_ - num_workers (int): _description_ - train_data_root (Path): _description_ - val_data_root (Path): _description_ - test_data_root (Path): _description_ - predict_data_root (Path): _description_ - img_grep (str): _description_ - label_grep (str): _description_ - means (list[float]): _description_ - stds (list[float]): _description_ - num_classes (int): _description_ - train_label_data_root (Path | None, optional): _description_. Defaults to None. - val_label_data_root (Path | None, optional): _description_. Defaults to None. - test_label_data_root (Path | None, optional): _description_. Defaults to None. - train_split (Path | None, optional): _description_. Defaults to None. - val_split (Path | None, optional): _description_. Defaults to None. - test_split (Path | None, optional): _description_. Defaults to None. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. - Naming must match that of dataset_bands. Defaults to None. - predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands - with this value at predict time. + batch_size (int): Number of samples in per batch. + modalities (list[str]): List of modalities. 
+ train_data_root (dict[Path]): Dictionary of paths to training data root directory or csv/parquet files with + image-level data, with modalities as keys. + val_data_root (dict[Path]): Dictionary of paths to validation data root directory or csv/parquet files with + image-level data, with modalities as keys. + test_data_root (dict[Path]): Dictionary of paths to test data root directory or csv/parquet files with + image-level data, with modalities as keys. + means (dict[list]): Dictionary of mean values as lists with modalities as keys. + stds (dict[list]): Dictionary of std values as lists with modalities as keys. + task (str, optional): Selected task from segmentation, regression (pixel-wise), classification, + multilabel_classification, scalar_regression, scalar (custom image-level task), or None (no targets). + Defaults to None. + num_classes (int, optional): Number of classes in classification or segmentation tasks. + predict_data_root (dict[Path], optional): Dictionary of paths to data root directory or csv/parquet files + with image-level data, with modalities as keys. + image_grep (dict[str], optional): Dictionary with regular expression appended to data_root to find input + images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False. + label_grep (str, optional): Regular expression appended to label_data_root to find labels or mask files. + Defaults to "*". Ignored when allow_substring_file_names is False. + train_label_data_root (Path | None, optional): Path to data root directory with training labels or + csv/parquet files with labels. Required for supervised tasks. + val_label_data_root (Path | None, optional): Path to data root directory with validation labels or + csv/parquet files with labels. Required for supervised tasks. + test_label_data_root (Path | None, optional): Path to data root directory with test labels or + csv/parquet files with labels. Required for supervised tasks. + train_split (Path, optional): Path to file containing training samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None. + val_split (Path, optional): Path to file containing validation samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None. + test_split (Path, optional): Path to file containing test samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None.
+ dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities + as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so + that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset + of all modalities. Defaults to None. + output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands, + provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None. + predict_dataset_bands (dict[list], optional): Overwrites dataset_bands with this value at predict time. Defaults to None, which does not overwrite. - predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands - with this value at predict time. Defaults to None, which does not overwrite. - constant_scale (float, optional): _description_. Defaults to 1. + predict_output_bands (dict[list], optional): Overwrites output_bands with this value at predict time. + Defaults to None, which does not overwrite. + image_modalities(list[str], optional): List of pixel-level raster modalities. Defaults to data_root.keys(). + The difference between all modalities and image_modalities are non-image modalities which are treated + differently during the transforms and are not modified but only converted into a tensor if possible. + rgb_modality (str, optional): Modality used for RGB plots. Defaults to first modality in data_root.keys(). rgb_indices (list[int] | None, optional): _description_. Defaults to None. - train_transform (Albumentations.Compose | None): Albumentations transform - to be applied to the train dataset. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data. + allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding + image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. + Defaults to False. + class_names (list[str], optional): Names of the classes. Defaults to None. + constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as + keys. Can be subset of all modalities. Defaults to None. + train_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to + non-image data, which is only converted to tensors if possible. If dict, can include separate transforms + per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - val_transform (Albumentations.Compose | None): Albumentations transform - to be applied to the train dataset. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data.
+ test_transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to + non-image data, which is only converted to tensors if possible. If dict, can include separate transforms + per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. - no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + shared_transforms (bool): transforms are shared between all image modalities (e.g., similar crop). + This setting is ignored if transforms are defined per modality. Defaults to True. expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False. + no_data_replace (float | None): Replace nan values in input images with this value. If none, does no + replacement. Defaults to None. + no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. + Defaults to None. reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. drop_last (bool): Drop the last batch if it is not complete. Defaults to True. - pin_memory (bool): If ``True``, the data loader will copy Tensors - into device/CUDA pinned memory before returning them. Defaults to False. - + num_workers (int): Number of parallel workers. Defaults to 0 for single threaded process. + pin_memory (bool): If ``True``, the data loader will copy Tensors into device/CUDA pinned memory before + returning them. Defaults to False. + data_with_sample_dim (bool): Use a specific collate function to concatenate samples along a existing sample + dimension instead of stacking the samples. Defaults to False. + sample_num_modalities (int, optional): Load only a subset of modalities per batch. Defaults to None. + sample_replace (bool): If sample_num_modalities is set, sample modalities with replacement. + Defaults to False. + channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3. + concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so + that it can be processed by single-modal models. Concatenate in the order of provided modalities. + Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False. 
""" if task == "segmentation": dataset_class = GenericMultimodalSegmentationDataset @@ -270,15 +314,16 @@ def __init__( super().__init__(dataset_class, batch_size, num_workers, **kwargs) self.num_classes = num_classes + self.class_names = class_names self.modalities = modalities self.image_modalities = image_modalities or modalities self.non_image_modalities = list(set(self.modalities) - set(self.image_modalities)) if task == "scalar": self.non_image_modalities += ["label"] - if isinstance(img_grep, dict): - self.img_grep = {m: img_grep[m] if m in img_grep else "*" for m in modalities} + if isinstance(image_grep, dict): + self.image_grep = {m: image_grep[m] if m in image_grep else "*" for m in modalities} else: - self.img_grep = {m: img_grep or "*" for m in modalities} + self.image_grep = {m: image_grep or "*" for m in modalities} self.label_grep = label_grep or "*" self.train_root = train_data_root self.val_root = val_data_root @@ -290,7 +335,7 @@ def __init__( self.train_split = train_split self.val_split = val_split self.test_split = test_split - self.allow_substring_split_file = allow_substring_split_file + self.allow_substring_file_names = allow_substring_file_names self.constant_scale = constant_scale self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace @@ -358,9 +403,9 @@ def __init__( self.aug = MultimodalNormalize(means, stds) - self.chunk_data = chunk_data + self.data_with_sample_dim = data_with_sample_dim - self.collate_fn = collate_chunk_dicts if chunk_data else collate_samples + self.collate_fn = collate_chunk_dicts if data_with_sample_dim else collate_samples def setup(self, stage: str) -> None: @@ -368,11 +413,11 @@ def setup(self, stage: str) -> None: self.train_dataset = self.dataset_class( data_root=self.train_root, num_classes=self.num_classes, - image_grep=self.img_grep, + image_grep=self.image_grep, label_grep=self.label_grep, label_data_root=self.train_label_data_root, split=self.train_split, - allow_substring_split_file=self.allow_substring_split_file, + allow_substring_file_names=self.allow_substring_file_names, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -392,11 +437,11 @@ def setup(self, stage: str) -> None: self.val_dataset = self.dataset_class( data_root=self.val_root, num_classes=self.num_classes, - image_grep=self.img_grep, + image_grep=self.image_grep, label_grep=self.label_grep, label_data_root=self.val_label_data_root, split=self.val_split, - allow_substring_split_file=self.allow_substring_split_file, + allow_substring_file_names=self.allow_substring_file_names, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -416,11 +461,11 @@ def setup(self, stage: str) -> None: self.test_dataset = self.dataset_class( data_root=self.test_root, num_classes=self.num_classes, - image_grep=self.img_grep, + image_grep=self.image_grep, label_grep=self.label_grep, label_data_root=self.test_label_data_root, split=self.test_split, - allow_substring_split_file=self.allow_substring_split_file, + allow_substring_file_names=self.allow_substring_file_names, dataset_bands=self.dataset_bands, output_bands=self.output_bands, constant_scale=self.constant_scale, @@ -440,7 +485,7 @@ def setup(self, stage: str) -> None: self.predict_dataset = self.dataset_class( data_root=self.predict_root, num_classes=self.num_classes, - allow_substring_split_file=self.allow_substring_split_file, + allow_substring_file_names=self.allow_substring_file_names, 
dataset_bands=self.predict_dataset_bands, output_bands=self.predict_output_bands, constant_scale=self.constant_scale, diff --git a/terratorch/datasets/generic_multimodal_dataset.py index 414cef3f..df34b376 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -68,23 +68,23 @@ def __init__( self, data_root: dict[Path], label_data_root: Path | list[Path] | None = None, - image_grep: str | None = "*", + image_grep: dict[str] | None = "*", label_grep: str | None = "*", split: Path | None = None, image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, - allow_missing_modalities: bool = False, # TODO: Not implemented on a data module level yet. - allow_substring_split_file: bool = False, - dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + allow_missing_modalities: bool = False, + allow_substring_file_names: bool = False, + dataset_bands: dict[list] | None = None, + output_bands: dict[list] | None = None, constant_scale: dict[float] = None, transform: A.Compose | dict | None = None, no_data_replace: float | None = None, - no_label_replace: int | None = None, + no_label_replace: float | None = -1, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, - channel_position: int = -1, + channel_position: int = -3, # TODO Check position for zarr data scalar_label: bool = False, concat_bands: bool = False, *args, **kwargs, ) -> None: """Constructor Args: - data_root (Path): Path to data root directory - label_data_root (Path, optional): Path to data root directory with labels. - If not specified, will use the same as for images. - image_grep (str, optional): Regular expression appended to data_root to find input images. - Defaults to "*". - label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. - Defaults to "*". - split (Path, optional): Path to file containing files to be used for this split. - The file should be a new-line separated prefixes contained in the desired files. - Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + data_root (dict[Path]): Dictionary of paths to data root directory or csv/parquet files with image-level + data, with modalities as keys. + label_data_root (Path, optional): Path to data root directory with labels or csv/parquet files with + image-level labels. Needs to be specified for supervised tasks. + image_grep (dict[str], optional): Dictionary with regular expression appended to data_root to find input + images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False. + label_grep (str, optional): Regular expression appended to label_data_root to find labels or mask files. + Defaults to "*". Ignored when allow_substring_file_names is False. + split (Path, optional): Path to file containing samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None.
+ image_modalities(list[str], optional): List of pixel-level raster modalities. Defaults to data_root.keys(). + The difference between all modalities and image_modalities are non-image modalities which are treated + differently during the transforms and are not modified but only converted into a tensor if possible. + rgb_modality (str, optional): Modality used for RGB plots. Defaults to first modality in data_root.keys(). rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None. - output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands. - constant_scale (float): Factor to multiply image values by. Defaults to 1. - transform (Albumentations.Compose | None): Albumentations transform to be applied. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data. + allow_missing_modalities (bool, optional): Allow missing modalities during data loading. Defaults to False. + TODO: Currently not implemented on a data module level! + allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding + image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. + If True and no split file is provided, considers the file stem as prefix, otherwise the full file name. + Defaults to True. + dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities + as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so + that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset + of all modalities. Defaults to None. + output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands, + provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None. + constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as + keys. Can be subset of all modalities. Defaults to None. + transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities (transformation are shared between image modalities, e.g., similar crop or rotation). + Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. + Not supported for multi-temporal data. The transform is not applied to non-image data, which is only + converted to tensors if possible. If dict, can include multiple transforms per modality which are + applied separately (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. - no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to -1. + no_data_replace (float | None): Replace nan values in input data with this value. + If None, does no replacement. Defaults to None. 
+ no_label_replace (float | None): Replace nan values in label with this value. + If None, does no replacement. Defaults to -1. expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False. reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. + channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3. + scalar_label (bool): Returns an image mask if False, otherwise the raw labels. Defaults to False. + concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so + that it can be processed by single-modal models. Concatenate in the order of provided modalities. + Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False. """ super().__init__() @@ -149,8 +178,8 @@ def __init__( assert not self.concat_bands or not allow_missing_modalities, \ "concat_bands cannot be used with allow_missing_modalities." - if self.expand_temporal_dimension and len(dataset_bands) != self.modalities: - msg = "Please provide dataset_bands for each modality when expand_temporal_dimension is True" + if self.expand_temporal_dimension and dataset_bands is None: + msg = "Please provide dataset_bands when expand_temporal_dimension is True" raise Exception(msg) # Load samples based on split file @@ -172,7 +201,7 @@ def __init__( if label_data_root is not None: dir_lists = [glob.glob(os.path.join(r, label_grep)) for r in label_data_root] image_files["mask"] = sorted([p for l in dir_lists for p in l]) # Concatenate - if allow_substring_split_file: + if allow_substring_file_names: # Get exact match of filenames get_file_id = lambda s: os.path.basename(s) else: @@ -230,7 +259,7 @@ def __init__( # Iterate over all directories of the current modality for m_path in m_paths: - if allow_substring_split_file: + if allow_substring_file_names: # Substring match with image_grep m_files = glob.glob(os.path.join(m_path, file + image_grep[m])) if m_files: @@ -249,7 +278,7 @@ def __init__( sample["mask"] = label_data_root.loc[file].values else: for l_dir in label_data_root: - if allow_substring_split_file: + if allow_substring_file_names: # Substring match with label_grep l_files = glob.glob(os.path.join(l_dir, file + label_grep)) if l_files: @@ -412,49 +441,76 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[str] | None = None, allow_missing_modalities: bool = False, - allow_substring_split_file: bool = False, - dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + allow_substring_file_names: bool = False, + dataset_bands: dict[list] | None = None, + output_bands: dict[list] | None = None, class_names: list[str] | None = None, - constant_scale: float = 1, + constant_scale: dict[float] = 1., transform: A.Compose | None = None, no_data_replace: float | None = None, - no_label_replace: int | None = None, + no_label_replace: int | None = -1, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -3, + concat_bands: bool = False, *args, **kwargs, ) -> None: """Constructor Args: - TODO: Update docs - data_root (Path): Path to data root directory - num_classes (int): Number of classes in the dataset - label_data_root (Path, optional): Path to data root directory with
labels. - If not specified, will use the same as for images. - image_grep (str, optional): Regular expression appended to data_root to find input images. - Defaults to "*". - label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. - Defaults to "*". - split (Path, optional): Path to file containing files to be used for this split. - The file should be a new-line separated prefixes contained in the desired files. - Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + data_root (dict[Path]): Dictionary of paths to data root directory or csv/parquet files with image-level + data, with modalities as keys. + num_classes (int): Number of classes. + label_data_root (Path): Path to data root directory with mask files. + image_grep (dict[str], optional): Dictionary with regular expression appended to data_root to find input + images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False. + label_grep (str, optional): Regular expression appended to label_data_root to find mask files. + Defaults to "*". Ignored when allow_substring_file_names is False. + split (Path, optional): Path to file containing samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None. + image_modalities(list[str], optional): List of pixel-level raster modalities. Defaults to data_root.keys(). + The difference between all modalities and image_modalities are non-image modalities which are treated + differently during the transforms and are not modified but only converted into a tensor if possible. + rgb_modality (str, optional): Modality used for RGB plots. Defaults to first modality in data_root.keys(). rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. - class_names (list[str], optional): Class names. Defaults to None. - constant_scale (float): Factor to multiply image values by. Defaults to 1. - transform (Albumentations.Compose | None): Albumentations transform to be applied. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data. + allow_missing_modalities (bool, optional): Allow missing modalities during data loading. Defaults to False. + TODO: Currently not implemented on a data module level! + allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding + image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. + If True and no split file is provided, considers the file stem as prefix, otherwise the full file name. + Defaults to True. + dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities + as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so + that they can then be referred to by output_bands. Needs to be superset of output_bands. 
Can be a subset + of all modalities. Defaults to None. + output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands, + provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None. + class_names (list[str], optional): Names of the classes. Defaults to None. + constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as + keys. Can be subset of all modalities. Defaults to None. + transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities (transformation are shared between image modalities, e.g., similar crop or rotation). + Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. + Not supported for multi-temporal data. The transform is not applied to non-image data, which is only + converted to tensors if possible. If dict, can include multiple transforms per modality which are + applied separately (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. - no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + no_data_replace (float | None): Replace nan values in input data with this value. + If None, does no replacement. Defaults to None. + no_label_replace (float | None): Replace nan values in label with this value. + If none, does no replacement. Defaults to -1. expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False. reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. + channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3. + concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so + that it can be processed by single-modal models. Concatenate in the order of provided modalities. + Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False. """ assert label_data_root is not None, "label_data_root must be specified for segmentation tasks." 
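For reference, a minimal sketch of constructing this segmentation dataset directly (outside the data module) could look like the following. The directory layout, modality names, and band names are illustrative assumptions, not part of the patch.

from terratorch.datasets.generic_multimodal_dataset import GenericMultimodalSegmentationDataset

# Hypothetical layout: one directory per modality plus a mask directory and a txt split file.
dataset = GenericMultimodalSegmentationDataset(
    data_root={"S2L2A": "data/s2", "S1": "data/s1"},
    num_classes=2,
    label_data_root="data/masks",
    image_grep={"S2L2A": "*_S2.tif", "S1": "*_S1.tif"},
    label_grep="*_mask.tif",
    split="splits/train.txt",            # newline-separated sample prefixes
    allow_substring_file_names=True,
    dataset_bands={"S2L2A": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]},
    output_bands={"S2L2A": ["BLUE", "GREEN", "RED"]},
    rgb_modality="S2L2A",
)
sample = dataset[0]
# sample["image"] is a dict keyed by modality; sample["mask"] holds the segmentation target.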
@@ -468,7 +524,7 @@ def __init__( rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, - allow_substring_split_file=allow_substring_split_file, + allow_substring_file_names=allow_substring_file_names, dataset_bands=dataset_bands, output_bands=output_bands, constant_scale=constant_scale, @@ -478,6 +534,7 @@ def __init__( expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, channel_position=channel_position, + concat_bands=concat_bands, *args, **kwargs, ) self.num_classes = num_classes @@ -581,46 +638,71 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities : bool = False, - allow_substring_split_file: bool = False, - dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - constant_scale: float = 1., # TODO: Check types of args - transform: A.Compose | None = None, + allow_substring_file_names: bool = False, + dataset_bands: dict[list] | None = None, + output_bands: dict[list] | None = None, + constant_scale: dict[float] = 1., + transform: A.Compose | dict | None = None, no_data_replace: float | None = None, - no_label_replace: int | None = None, + no_label_replace: float | None = None, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -3, + concat_bands: bool = False, *args, **kwargs, ) -> None: """Constructor Args: - TODO: Update docs - data_root (Path): Path to data root directory - label_data_root (Path, optional): Path to data root directory with labels. - If not specified, will use the same as for images. - image_grep (str, optional): Regular expression appended to data_root to find input images. - Defaults to "*". - label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. - Defaults to "*". - split (Path, optional): Path to file containing files to be used for this split. - The file should be a new-line separated prefixes contained in the desired files. - Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + data_root (dict[Path]): Dictionary of paths to data root directory or csv/parquet files with image-level + data, with modalities as keys. + label_data_root (Path): Path to data root directory with ground truth files. + image_grep (dict[str], optional): Dictionary with regular expression appended to data_root to find input + images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False. + label_grep (str, optional): Regular expression appended to label_data_root to find ground truth files. + Defaults to "*". Ignored when allow_substring_file_names is False. + split (Path, optional): Path to file containing samples prefixes to be used for this split. + The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None. + image_modalities(list[str], optional): List of pixel-level raster modalities. Defaults to data_root.keys(). 
+ The difference between all modalities and image_modalities are non-image modalities which are treated + differently during the transforms and are not modified but only converted into a tensor if possible. + rgb_modality (str, optional): Modality used for RGB plots. Defaults to first modality in data_root.keys(). rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. - constant_scale (float): Factor to multiply image values by. Defaults to 1. - transform (Albumentations.Compose | None): Albumentations transform to be applied. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data. + allow_missing_modalities (bool, optional): Allow missing modalities during data loading. Defaults to False. + TODO: Currently not implemented on a data module level! + allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding + image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. + If True and no split file is provided, considers the file stem as prefix, otherwise the full file name. + Defaults to True. + dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities + as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so + that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset + of all modalities. Defaults to None. + output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands, + provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None. + constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as + keys. Can be subset of all modalities. Defaults to None. + transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities. Should end with ToTensorV2() and not include normalization. The transform is not applied to + non-image data, which is only converted to tensors if possible. If dict, can include separate transforms + per modality (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. - no_label_replace (int | None): Replace nan values in label with this value. If none, does no replacement. Defaults to None. + no_data_replace (float | None): Replace nan values in input data with this value. + If None, does no replacement. Defaults to None. + no_label_replace (float | None): Replace nan values in label with this value. + If none, does no replacement. Defaults to None. expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False. reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. + channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3. 
+ concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so + that it can be processed by single-modal models. Concatenate in the order of provided modalities. + Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False. """ assert label_data_root is not None, "label_data_root must be specified for regression tasks." @@ -634,7 +716,7 @@ def __init__( rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, - allow_substring_split_file=allow_substring_split_file, + allow_substring_file_names=allow_substring_file_names, dataset_bands=dataset_bands, output_bands=output_bands, constant_scale=constant_scale, @@ -644,6 +726,7 @@ def __init__( expand_temporal_dimension=expand_temporal_dimension, reduce_zero_label=reduce_zero_label, channel_position=channel_position, + concat_bands=concat_bands, *args, **kwargs, ) @@ -725,6 +808,7 @@ class GenericMultimodalScalarDataset(GenericMultimodalDataset): def __init__( self, data_root: Path, + num_classes: int, label_data_root: Path, image_grep: str | None = "*", label_grep: str | None = "*", @@ -733,46 +817,76 @@ def __init__( rgb_modality: str | None = None, rgb_indices: list[int] | None = None, allow_missing_modalities : bool = False, - allow_substring_split_file: bool = False, + allow_substring_file_names: bool = False, dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, - constant_scale: float = 1., # TODO: Check types of args + class_names: list[str] | None = None, + constant_scale: dict[float] = 1., transform: A.Compose | None = None, no_data_replace: float | None = None, no_label_replace: int | None = None, expand_temporal_dimension: bool = False, reduce_zero_label: bool = False, channel_position: int = -3, + concat_bands: bool = False, *args, **kwargs, ) -> None: """Constructor Args: - TODO: Update docs - data_root (Path): Path to data root directory - label_data_root (Path, optional): Path to data root directory with labels. - If not specified, will use the same as for images. - image_grep (str, optional): Regular expression appended to data_root to find input images. - Defaults to "*". - label_grep (str, optional): Regular expression appended to data_root to find ground truth masks. - Defaults to "*". - split (Path, optional): Path to file containing files to be used for this split. - The file should be a new-line separated prefixes contained in the desired files. - Files will be seached using glob with the form Path(data_root).glob(prefix + [image or label grep]) + data_root (dict[Path]): Dictionary of paths to data root directory or csv/parquet files with image-level + data, with modalities as keys. + num_classes (int): Number of classes. + label_data_root (Path): Path to data root directory with labels or csv/parquet files with labels. + image_grep (dict[str], optional): Dictionary with regular expression appended to data_root to find input + images, with modalities as keys. Defaults to "*". Ignored when allow_substring_file_names is False. + label_grep (str, optional): Regular expression appended to label_data_root to find labels files. + Defaults to "*". Ignored when allow_substring_file_names is False. + split (Path, optional): Path to file containing samples prefixes to be used for this split. 
+ The file can be a csv/parquet file with the prefixes in the index or a txt file with new-line separated + sample prefixes. File names must be exact matches if allow_substring_file_names is False. Otherwise, + files are searched using glob with the form Path(data_root).glob(prefix + [image or label grep]). + If not specified, search samples based on files in data_root. Defaults to None. + image_modalities(list[str], optional): List of pixel-level raster modalities. Defaults to data_root.keys(). + The difference between all modalities and image_modalities are non-image modalities which are treated + differently during the transforms and are not modified but only converted into a tensor if possible. + rgb_modality (str, optional): Modality used for RGB plots. Defaults to first modality in data_root.keys(). rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. - constant_scale (float): Factor to multiply image values by. Defaults to 1. - transform (Albumentations.Compose | None): Albumentations transform to be applied. - Should end with ToTensorV2(). If used through the generic_data_module, - should not include normalization. Not supported for multi-temporal data. + allow_missing_modalities (bool, optional): Allow missing modalities during data loading. Defaults to False. + TODO: Currently not implemented on a data module level! + allow_substring_file_names (bool, optional): Allow substrings during sample identification by adding + image or label grep to the sample prefixes. If False, treats sample prefixes as full file names. + If True and no split file is provided, considers the file stem as prefix, otherwise the full file name. + Defaults to True. + dataset_bands (dict[list], optional): Bands present in the dataset, provided in a dictionary with modalities + as keys. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so + that they can then be referred to by output_bands. Needs to be superset of output_bands. Can be a subset + of all modalities. Defaults to None. + output_bands (dict[list], optional): Bands that should be output by the dataset as named by dataset_bands, + provided as a dictionary with modality keys. Can be subset of all modalities. Defaults to None. + class_names (list[str], optional): Names of the classes. Defaults to None. + constant_scale (dict[float]): Factor to multiply data values by, provided as a dictionary with modalities as + keys. Can be subset of all modalities. Defaults to None. + transform (Albumentations.Compose | dict | None): Albumentations transform to be applied to all image + modalities (transformation are shared between image modalities, e.g., similar crop or rotation). + Should end with ToTensorV2(). If used through the generic_data_module, should not include normalization. + Not supported for multi-temporal data. The transform is not applied to non-image data, which is only + converted to tensors if possible. If dict, can include multiple transforms per modality which are + applied separately (no shared parameters between modalities). Defaults to None, which simply applies ToTensorV2(). - no_data_replace (float | None): Replace nan values in input images with this value. If none, does no replacement. Defaults to None. - no_label_replace (int | None): Replace nan values in label with this value. 
If none, does no replacement. Defaults to None. + no_data_replace (float | None): Replace nan values in input data with this value. + If None, does no replacement. Defaults to None. + no_label_replace (float | None): Replace nan values in label with this value. + If none, does no replacement. Defaults to -1. expand_temporal_dimension (bool): Go from shape (time*channels, h, w) to (channels, time, h, w). + Only works with image modalities. Is only applied to modalities with defined dataset_bands. Defaults to False. reduce_zero_label (bool): Subtract 1 from all labels. Useful when labels start from 1 instead of the expected 0. Defaults to False. + channel_position (int): Position of the channel dimension in the image modalities. Defaults to -3. + concat_bands (bool): Concatenate all image modalities along the band dimension into a single "image", so + that it can be processed by single-modal models. Concatenate in the order of provided modalities. + Works with image modalities only. Does not work with allow_missing_modalities. Defaults to False. """ assert label_data_root is not None, "label_data_root must be specified for scalar tasks." @@ -786,7 +900,7 @@ def __init__( rgb_modality=rgb_modality, rgb_indices=rgb_indices, allow_missing_modalities=allow_missing_modalities, - allow_substring_split_file=allow_substring_split_file, + allow_substring_file_names=allow_substring_file_names, dataset_bands=dataset_bands, output_bands=output_bands, constant_scale=constant_scale, @@ -797,9 +911,14 @@ def __init__( reduce_zero_label=reduce_zero_label, channel_position=channel_position, scalar_label=True, + concat_bands=concat_bands, *args, **kwargs, ) + self.num_classes = num_classes + self.class_names = class_names + + def __getitem__(self, index: int) -> dict[str, Any]: item = super().__getitem__(index) return item From d6a9ab99f311b14d6b436d9fa62aa0afe41f0392 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Thu, 28 Nov 2024 17:04:37 +0100 Subject: [PATCH 28/42] Fix type error Signed-off-by: Benedikt Blumenstiel --- examples/confs/multimae_sen1floods11.yaml | 8 ++--- .../multimodal_prithvi_sen1floods11.yaml | 2 +- .../generic_multimodal_data_module.py | 34 +++++++++---------- .../datasets/generic_multimodal_dataset.py | 12 +++---- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/confs/multimae_sen1floods11.yaml b/examples/confs/multimae_sen1floods11.yaml index 62a7acb2..9d779ab3 100644 --- a/examples/confs/multimae_sen1floods11.yaml +++ b/examples/confs/multimae_sen1floods11.yaml @@ -64,7 +64,7 @@ data: test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt allow_substring_file_names: True - img_grep: + image_grep: S2L2A: "*_S2L2AHand.tif" S1: "*_S1Hand.tif" LULC: "*_LULCHand.npy" @@ -130,9 +130,9 @@ model: - S1 - S2L2A - LULC - decoder: UperNetDecoder # FCNDecoder - # decoder_num_convs: 4 # only for FCNDecoder - decoder_scale_modules: True # only for UperNetDecoder + decoder: FCNDecoder # UperNetDecoder + decoder_num_convs: 4 # only for FCNDecoder + # decoder_scale_modules: True # only for UperNetDecoder decoder_channels: 256 num_classes: 2 head_dropout: 0.1 diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml index d25778a2..768e85be 100644 --- a/examples/confs/multimodal_prithvi_sen1floods11.yaml +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -61,7 +61,7 @@ data: test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt allow_substring_file_names: 
True - img_grep: + image_grep: S2L2A: "*_S2L2AHand.tif" S1: "*_S1Hand.tif" label_grep: "*_LabelHand.tif" diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py index 2101f9bb..8423ee86 100644 --- a/terratorch/datamodules/generic_multimodal_data_module.py +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -153,26 +153,26 @@ def __init__( self, batch_size: int, modalities: list[str], - train_data_root: dict[Path], - val_data_root: dict[Path], - test_data_root: dict[Path], - means: dict[list], - stds: dict[list], + train_data_root: dict[str, Path], + val_data_root: dict[str, Path], + test_data_root: dict[str, Path], + means: dict[str, list], + stds: dict[str, list], task: str | None = None, num_classes: int | None = None, - image_grep: str | dict | None = None, + image_grep: str | dict[str, str] | None = None, label_grep: str | None = None, - train_label_data_root: Path | None = None, - val_label_data_root: Path | None = None, - test_label_data_root: Path | None = None, - predict_data_root: dict[Path] | None = None, - train_split: Path | None = None, - val_split: Path | None = None, - test_split: Path | None = None, - dataset_bands: dict[list] | None = None, - output_bands: dict[list] | None = None, - predict_dataset_bands: dict[list] | None = None, - predict_output_bands: dict[list] | None = None, + train_label_data_root: Path | str | None = None, + val_label_data_root: Path | str | None = None, + test_label_data_root: Path | str | None = None, + predict_data_root: dict[str, Path] | str | None = None, + train_split: Path | str | None = None, + val_split: Path | str | None = None, + test_split: Path| str | None = None, + dataset_bands: dict[str, list] | None = None, + output_bands: dict[str, list] | None = None, + predict_dataset_bands: dict[str, list] | None = None, + predict_output_bands: dict[str, list] | None = None, image_modalities: list[str] | None = None, rgb_modality: str | None = None, rgb_indices: list[int] | None = None, diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py index df34b376..1c0af790 100644 --- a/terratorch/datasets/generic_multimodal_dataset.py +++ b/terratorch/datasets/generic_multimodal_dataset.py @@ -66,9 +66,9 @@ class GenericMultimodalDataset(NonGeoDataset, ABC): def __init__( self, - data_root: dict[Path], - label_data_root: Path | list[Path] | None = None, - image_grep: dict[str] | None = "*", + data_root: dict[str, Path | str], + label_data_root: Path | str | list[Path | str] | None = None, + image_grep: dict[str, str] | None = "*", label_grep: str | None = "*", split: Path | None = None, image_modalities: list[str] | None = None, @@ -76,9 +76,9 @@ def __init__( rgb_indices: list[int] | None = None, allow_missing_modalities: bool = False, allow_substring_file_names: bool = False, - dataset_bands: dict[list] | None = None, - output_bands: dict[list] | None = None, - constant_scale: dict[float] = None, + dataset_bands: dict[str, list] | None = None, + output_bands: dict[str, list] | None = None, + constant_scale: dict[str, float] = None, transform: A.Compose | dict | None = None, no_data_replace: float | None = None, no_label_replace: float | None = -1, From 9306b9374da4fe9446cacbf8a89ba29875113967 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Mon, 2 Dec 2024 09:39:34 +0100 Subject: [PATCH 29/42] Ensure image modality to be loaded first Signed-off-by: Benedikt Blumenstiel --- 
 terratorch/datasets/generic_multimodal_dataset.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py
index 1c0af790..ad92c161 100644
--- a/terratorch/datasets/generic_multimodal_dataset.py
+++ b/terratorch/datasets/generic_multimodal_dataset.py
@@ -154,13 +154,14 @@ def __init__(
         assert "mask" not in self.modalities, "Modality cannot be called 'mask'."
         self.image_modalities = image_modalities or self.modalities
         self.non_image_modalities = list(set(self.modalities) - set(image_modalities))
+        self.modalities = self.image_modalities + self.non_image_modalities  # Ensure image modalities come first
         if scalar_label:
             self.non_image_modalities += ["label"]

-        # Convert path strings to lists as the code expects a list of paths per modality
-        for m, m_path in data_root.items():
-            if not isinstance(m_path, list):
-                data_root[m] = [m_path]
+        # Order by modalities and convert path strings to lists as the code expects a list of paths per modality
+        data_root = {m: data_root[m] if isinstance(data_root[m], list) else [data_root[m]]
+                     for m in self.modalities}
+
         if label_data_root is not None and not isinstance(label_data_root, list):
             label_data_root = [label_data_root]


From 6f49b6775132deb8590c5c72516331c13ef8f3bf Mon Sep 17 00:00:00 2001
From: Benedikt Blumenstiel
Date: Mon, 2 Dec 2024 09:40:08 +0100
Subject: [PATCH 30/42] Add possibility to pass input_size in kwargs

Signed-off-by: Benedikt Blumenstiel
---
 terratorch/models/pixel_wise_model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/terratorch/models/pixel_wise_model.py b/terratorch/models/pixel_wise_model.py
index dc960926..6b9145c8 100644
--- a/terratorch/models/pixel_wise_model.py
+++ b/terratorch/models/pixel_wise_model.py
@@ -92,9 +92,13 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
         self.check_input_shape(x)
         if isinstance(x, torch.Tensor):
             input_size = x.shape[-2:]
+        elif 'image_size' in kwargs:
+            input_size = kwargs['image_size']
         elif isinstance(x, dict):
             # Multimodal input in passed as dict
             input_size = list(x.values())[0].shape[-2:]
+        else:
+            raise ValueError('Could not infer input shape.')
         features = self.encoder(x, **kwargs)

         ## only for backwards compatibility with pre-neck times.

From d23dd629caa5f38bdbc8d9ba524d2cbdc70a560a Mon Sep 17 00:00:00 2001
From: Benedikt Blumenstiel
Date: Mon, 2 Dec 2024 09:41:09 +0100
Subject: [PATCH 31/42] Added check for rgb_modality

Signed-off-by: Benedikt Blumenstiel
---
 terratorch/tasks/segmentation_tasks.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index ae1b98f1..da13c1b3 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -288,9 +288,15 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
             try:
                 datamodule = self.trainer.datamodule
                 batch["prediction"] = y_hat_hard
+
                 if isinstance(batch["image"], dict):
-                    # Multimodal input
-                    batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality]
+                    if hasattr(datamodule, 'rgb_modality'):
+                        # Generic multimodal dataset
+                        batch["image"] = batch["image"][datamodule.rgb_modality]
+                    else:
+                        # Multimodal dataset. Assuming first item to be the modality to visualize.
+ batch["image"] = batch["image"][list(batch["image"].keys())[0]] + for key in ["image", "mask", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] From 98416abd07aa85323ce4213efb70aa9a45f574ce Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Mon, 2 Dec 2024 11:05:09 +0100 Subject: [PATCH 32/42] Updated multimodal prithvi example Signed-off-by: Benedikt Blumenstiel --- .../multimodal_prithvi_sen1floods11.yaml | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/confs/multimodal_prithvi_sen1floods11.yaml b/examples/confs/multimodal_prithvi_sen1floods11.yaml index 768e85be..0a7633c8 100644 --- a/examples/confs/multimodal_prithvi_sen1floods11.yaml +++ b/examples/confs/multimodal_prithvi_sen1floods11.yaml @@ -22,7 +22,7 @@ trainer: monitor: val/loss patience: 40 - max_epochs: 5 + max_epochs: 100 check_val_every_n_epoch: 1 log_every_n_steps: 50 enable_checkpointing: True @@ -32,9 +32,9 @@ data: class_path: GenericMultiModalDataModule init_args: task: 'segmentation' - batch_size: 2 - num_workers: 0 - modalities: + batch_size: 16 + num_workers: 4 + modalities: # Define names of modalities - S2L2A - S1 rgb_modality: S2L2A # If not provided, uses first modality @@ -43,6 +43,7 @@ data: - 2 - 1 + # Data roots are defined as dicts with modalities as keys train_data_root: S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand @@ -56,9 +57,9 @@ data: S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand - train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt - val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt - test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt + train_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_train_data.txt + val_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data.txt + test_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_test_data.txt allow_substring_file_names: True image_grep: @@ -67,8 +68,9 @@ data: label_grep: "*_LabelHand.tif" no_label_replace: -1 no_data_replace: 0 - concat_bands: true # Concatenate S2 and S2 bands + concat_bands: true # Concatenate modalities along band dim for single-modal models like Prithvi + # Define standardization values as dicts (no scaling if modality is not included) means: S2L2A: - 1793.243 @@ -107,6 +109,7 @@ data: num_classes: 2 + # Transforms are shared between all image modalities (e.g. 
same crop area) train_transform: - class_path: albumentations.RandomCrop init_args: From 43fbfea6f565b63dc84521053955f022372a9087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 2 Dec 2024 17:27:34 -0300 Subject: [PATCH 33/42] updating documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- docs/quick_start.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/quick_start.md b/docs/quick_start.md index a7c1925c..446ea1b4 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -107,9 +107,8 @@ model_args = dict( HLSBands.SWIR_1, HLSBands.SWIR_2, ], - necks=[{"name": "SelectIndices", "indices": -1}, + necks=[{"name": "SelectIndices", "indices": [-1]}, {"name": "ReshapeTokensToImage"}], - num_classes=4, backbone_pretrained=True, backbone_num_frames=1, decoder_channels=128, From 17bd8bede370bca8c3c1f116e905d02cb23c603a Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 09:35:47 +0100 Subject: [PATCH 34/42] Fix prithvi model Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2eb78236..7fee6027 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -283,7 +283,7 @@ def __init__(self, for i in range(depth): self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)) self.feature_info.append( - {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"} + {"num_chs": embed_dim * self.patch_embed.grid_size[0], "reduction": 1, "module": f"blocks.{i}"} ) self.blocks = nn.ModuleList(self.blocks) @@ -418,7 +418,7 @@ def forward_features( x = x + pos_embed[:, 1:, :] if self.temporal_encoding: - num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames + num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding if self.location_encoding: From 543858726824f6f4f6749c8c0825ef034f25e0bc Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 09:35:58 +0100 Subject: [PATCH 35/42] Add pretrained models Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index b56994c2..8d6ed2b4 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -30,10 +30,22 @@ "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M", "hf_hub_filename": "Prithvi_100M.pt", }, - "prithvi_eo_v2_300": {}, - "prithvi_eo_v2_300_tl": {}, - "prithvi_eo_v2_600": {}, - "prithvi_eo_v2_600_tl": {}, + "prithvi_eo_v2_300": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M", + "hf_hub_filename": "Prithvi_EO_V2_300M.pt", + }, + "prithvi_eo_v2_300_tl": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL", + "hf_hub_filename": "Prithvi_EO_V2_300M_TL.pt", + }, + "prithvi_eo_v2_600": { + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M", + "hf_hub_filename": "Prithvi_EO_V2_600M.pt", + }, + "prithvi_eo_v2_600_tl": { + "hf_hub_id": 
"ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL", + "hf_hub_filename": "Prithvi_EO_V2_600M_TL.pt", + }, "prithvi_vit_tiny": {} } ) From fd5ea2537e2ad275c64265219b46c4b98eadeef6 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 09:51:55 +0100 Subject: [PATCH 36/42] Renamed Prithvi V1 weights Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 8d6ed2b4..6bee25fa 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -27,8 +27,8 @@ default_cfgs = generate_default_cfgs( { "prithvi_vit_100": { - "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M", - "hf_hub_filename": "Prithvi_100M.pt", + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", + "hf_hub_filename": "Prithvi_EO_V1_100M.pt", }, "prithvi_eo_v2_300": { "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M", From 5ef8036fe1981dce80ac3190dddc010a410a28ac Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 11:46:37 +0100 Subject: [PATCH 37/42] Renamed Prithvi V1 weights Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 6bee25fa..9f9c8292 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -27,7 +27,7 @@ default_cfgs = generate_default_cfgs( { "prithvi_vit_100": { - "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-1.0-100M", + "hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M", "hf_hub_filename": "Prithvi_EO_V1_100M.pt", }, "prithvi_eo_v2_300": { From 8ec8991a9be962f6af4bd3a69fe5db138caa482e Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 14:25:24 +0100 Subject: [PATCH 38/42] Init prithvi from hf config Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 4 +- terratorch/models/backbones/prithvi_vit.py | 115 ++++----------------- 2 files changed, 20 insertions(+), 99 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 7fee6027..8a7c5276 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -240,7 +240,7 @@ def __init__(self, depth: int = 24, num_heads: int = 16, mlp_ratio: float = 4., - norm_layer: nn.Module = nn.LayerNorm, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), coords_encoding: List[str] | None = None, coords_scale_learn: bool = False, encoder_only: bool = True, # needed for timm @@ -598,7 +598,7 @@ def __init__(self, decoder_depth: int = 8, decoder_num_heads: int = 16, mlp_ratio: float = 4., - norm_layer: nn.Module = nn.LayerNorm, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), norm_pix_loss: bool = False, coords_encoding: List[str] | None = None, coords_scale_learn: bool = False, diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 9f9c8292..2cfdf223 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -4,11 +4,9 @@ import torch import logging from functools import partial - -from timm.models import FeatureInfo -from timm.models._builder import build_model_with_cfg -from timm.models._registry import 
generate_default_cfgs, register_model from torch import nn, Tensor +from timm.models import (FeatureInfo, load_model_config_from_hf, build_model_with_cfg, generate_default_cfgs, + register_model) from terratorch.datasets import HLSBands from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights @@ -50,6 +48,7 @@ } ) + def checkpoint_filter_fn_vit( state_dict, model: PrithviViT, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int] ) -> dict: @@ -124,6 +123,7 @@ def pad_images(imgs: Tensor,patch_size: int, padding:str) -> Tensor: ]) return imgs + def _create_prithvi( variant: str, pretrained: bool = False, # noqa: FBT001, FBT002 @@ -219,13 +219,13 @@ def forward_features_pad_images(*args, **kwargs): return model -def create_prithvi_vit_100( + +def create_prithvi_from_config( model_name: str, pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - """Prithvi ViT 100M""" pretrained_bands = PRETRAINED_BANDS if bands is None: bands = pretrained_bands @@ -234,102 +234,21 @@ def create_prithvi_vit_100( Pretrained patch_embed layer may be misaligned with current bands" ) - model_args = { - "patch_size": 16, - "embed_dim": 768, - "depth": 12, - "num_heads": 12, - "decoder_embed_dim": 512, - "decoder_depth": 8, - "decoder_num_heads": 16, - "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, - } + config, _ = load_model_config_from_hf(default_cfgs[model_name].default.hf_hub_id) + config.update(num_frames=1) # Assume one timestamp by default + config.update(kwargs) # Overwrite with keyword args model = _create_prithvi( model_name, pretrained=pretrained, model_bands=bands, pretrained_bands=pretrained_bands, - **dict(model_args,**kwargs), + **config, ) return model -def create_prithvi_vit_300( - model_name: str, - pretrained: bool = False, # noqa: FBT001, FBT002 - bands: list[HLSBands | int] | None = None, - **kwargs, -) -> PrithviViT: - """Prithvi ViT 300M""" - pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - logging.info( - f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - model_args = { - "patch_size": 16, - "embed_dim": 1024, - "depth": 24, - "num_heads": 16, - "decoder_embed_dim": 512, - "decoder_depth": 8, - "decoder_num_heads": 16, - "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, - } - model = _create_prithvi( - model_name, - pretrained=pretrained, - pretrained_bands=pretrained_bands, - model_bands=bands, - **dict(model_args, **kwargs), - ) - return model - - -def create_prithvi_vit_600( - model_name: str, - pretrained: bool = False, # noqa: FBT001, FBT002 - bands: list[HLSBands] | None = None, - **kwargs, -) -> PrithviViT: - """Prithvi ViT 600M""" - pretrained_bands = PRETRAINED_BANDS - if bands is None: - bands = pretrained_bands - logging.info( - f"Model bands not passed. 
Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\ - Pretrained patch_embed layer may be misaligned with current bands" - ) - model_args = { - "patch_size": 14, - "embed_dim": 1280, - "depth": 32, - "num_heads": 16, - "decoder_embed_dim": 512, - "decoder_depth": 8, - "decoder_num_heads": 16, - "mlp_ratio": 4, - "norm_layer": partial(nn.LayerNorm, eps=1e-6), - "num_frames": 1, - } - model = _create_prithvi( - model_name, - pretrained=pretrained, - pretrained_bands=pretrained_bands, - model_bands=bands, - **dict(model_args, **kwargs), - ) - return model - - @register_model def prithvi_vit_tiny( bands: list[HLSBands | int] | None = None, @@ -352,16 +271,18 @@ def prithvi_vit_tiny( "num_frames": 1, "model_bands": bands, } - model = _create_prithvi("prithvi_vit_tiny", **dict(model_args, **kwargs)) + model_args.update(kwargs) + model = _create_prithvi("prithvi_vit_tiny", **model_args) return model + @register_model def prithvi_vit_100( pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_vit_100("prithvi_vit_100", pretrained, bands, **kwargs) + return create_prithvi_from_config("prithvi_vit_100", pretrained, bands, **kwargs) @register_model @@ -370,7 +291,7 @@ def prithvi_eo_v2_300( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_vit_300("prithvi_eo_v2_300", pretrained, bands, **kwargs) + return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs) @register_model @@ -379,7 +300,7 @@ def prithvi_eo_v2_600( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_vit_600("prithvi_eo_v2_600", pretrained, bands, **kwargs) + return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs) @register_model @@ -388,7 +309,7 @@ def prithvi_eo_v2_300_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_vit_300("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) + return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) @register_model @@ -397,4 +318,4 @@ def prithvi_eo_v2_600_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_vit_600("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) \ No newline at end of file + return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) From 93df86e1e4ba3e48ebb0a7215e067f0c6754cc26 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 15:07:05 +0100 Subject: [PATCH 39/42] Fix patch_size lists in configs Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_vit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 2cfdf223..69eaa6bc 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -146,6 +146,8 @@ def _create_prithvi( padding = kwargs.get("padding", "none") patch_size = kwargs.get("patch_size", 16) + if isinstance(patch_size, list): + patch_size = patch_size[-1] # Little hack because VIT does not support timm's features_only encoder_only = kwargs.pop("features_only", False) From 27a6a681d9cafb3360c9e2e52ef01bdea79fbac2 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 15:30:39 +0100 Subject: [PATCH 40/42] Added default configs for prithvi Signed-off-by: Benedikt Blumenstiel --- 
terratorch/models/backbones/prithvi_mae.py | 10 +-- terratorch/models/backbones/prithvi_vit.py | 100 +++++++++++++++++++-- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 8a7c5276..2c037487 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -589,11 +589,11 @@ class PrithviMAE(nn.Module): def __init__(self, img_size: int | Tuple[int, int] = 224, patch_size: int | Tuple[int, int, int] = (1, 16, 16), - num_frames: int = 3, - in_chans: int = 3, - embed_dim: int = 1024, - depth: int = 24, - num_heads: int = 16, + num_frames: int = 4, + in_chans: int = 6, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, diff --git a/terratorch/models/backbones/prithvi_vit.py b/terratorch/models/backbones/prithvi_vit.py index 69eaa6bc..07a76a65 100644 --- a/terratorch/models/backbones/prithvi_vit.py +++ b/terratorch/models/backbones/prithvi_vit.py @@ -226,6 +226,7 @@ def create_prithvi_from_config( model_name: str, pretrained: bool = False, # noqa: FBT001, FBT002 bands: list[HLSBands] | None = None, + default_cfg: dict = None, **kwargs, ) -> PrithviViT: pretrained_bands = PRETRAINED_BANDS @@ -236,7 +237,11 @@ def create_prithvi_from_config( Pretrained patch_embed layer may be misaligned with current bands" ) - config, _ = load_model_config_from_hf(default_cfgs[model_name].default.hf_hub_id) + try: + config, _ = load_model_config_from_hf(default_cfgs[model_name].default.hf_hub_id) + except: + # No connection to hf + config = default_cfg config.update(num_frames=1) # Assume one timestamp by default config.update(kwargs) # Overwrite with keyword args @@ -284,7 +289,22 @@ def prithvi_vit_100( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_vit_100", pretrained, bands, **kwargs) + + default_config = { + "img_size": 224, + "patch_size": [1, 16, 16], + "num_frames": 3, + "in_chans": 6, + "embed_dim": 768, + "depth": 12, + "num_heads": 12, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + } + + return create_prithvi_from_config("prithvi_vit_100", pretrained, bands, default_config, **kwargs) @register_model @@ -293,7 +313,24 @@ def prithvi_eo_v2_300( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, **kwargs) + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 16, 16], + "in_chans": 6, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": [], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_300", pretrained, bands, default_config, **kwargs) @register_model @@ -302,7 +339,24 @@ def prithvi_eo_v2_600( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, **kwargs) + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 14, 14], + "in_chans": 6, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": [], + "coords_scale_learn": True, + } + + return 
create_prithvi_from_config("prithvi_eo_v2_600", pretrained, bands, default_config, **kwargs) @register_model @@ -311,7 +365,24 @@ def prithvi_eo_v2_300_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, **kwargs) + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 16, 16], + "in_chans": 6, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": ["time", "location"], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_300_tl", pretrained, bands, default_config, **kwargs) @register_model @@ -320,4 +391,21 @@ def prithvi_eo_v2_600_tl( bands: list[HLSBands] | None = None, **kwargs, ) -> PrithviViT: - return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, **kwargs) + + default_config = { + "img_size": 224, + "num_frames": 4, + "patch_size": [1, 14, 14], + "in_chans": 6, + "embed_dim": 1280, + "depth": 32, + "num_heads": 16, + "decoder_embed_dim": 512, + "decoder_depth": 8, + "decoder_num_heads": 16, + "mlp_ratio": 4, + "coords_encoding": ["time", "location"], + "coords_scale_learn": True, + } + + return create_prithvi_from_config("prithvi_eo_v2_600_tl", pretrained, bands, default_config, **kwargs) From a3384b967354cd172010e5a352b1d32fc86b953e Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 17:49:07 +0100 Subject: [PATCH 41/42] Added check for temp loc coords Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index 2c037487..c80737fe 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -373,11 +373,11 @@ def forward( # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] - if self.temporal_encoding: + if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding - if self.location_encoding: + if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_enc(location_coords) x = x + location_encoding @@ -417,11 +417,11 @@ def forward_features( # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] - if self.temporal_encoding: + if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding - if self.location_encoding: + if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_enc(location_coords) x = x + location_encoding @@ -556,12 +556,12 @@ def forward( # remove cls token x_ = x[:, 1:, :] - if self.temporal_encoding: + if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x_.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) # Add temporal encoding w/o cls token x_ = x_ + temporal_encoding - if self.location_encoding: + if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_dec(location_coords) # Add 
location encoding w/o cls token x_ = x_ + location_encoding From 1daca1159bf9fd998e23f196b48a727582a1bdb8 Mon Sep 17 00:00:00 2001 From: Benedikt Blumenstiel Date: Tue, 3 Dec 2024 18:03:45 +0100 Subject: [PATCH 42/42] Changed back layer norm default Signed-off-by: Benedikt Blumenstiel --- terratorch/models/backbones/prithvi_mae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index c80737fe..e8483207 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -240,7 +240,7 @@ def __init__(self, depth: int = 24, num_heads: int = 16, mlp_ratio: float = 4., - norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + norm_layer: nn.Module = nn.LayerNorm, coords_encoding: List[str] | None = None, coords_scale_learn: bool = False, encoder_only: bool = True, # needed for timm @@ -598,7 +598,7 @@ def __init__(self, decoder_depth: int = 8, decoder_num_heads: int = 16, mlp_ratio: float = 4., - norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + norm_layer: nn.Module = nn.LayerNorm, norm_pix_loss: bool = False, coords_encoding: List[str] | None = None, coords_scale_learn: bool = False,
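
A minimal usage sketch (illustrative only, not part of any patch above; assuming that importing terratorch runs the
@register_model entrypoints defined in prithvi_vit.py, and that the HLSBands names below match PRETRAINED_BANDS): with
the registrations and Hugging Face Hub configs added in patches 35-40, a Prithvi EO v2 backbone could be created
through timm roughly as follows, falling back to the hard-coded default config when the Hub is unreachable.

    import timm
    import terratorch  # importing terratorch registers the Prithvi backbones with timm
    from terratorch.datasets import HLSBands

    # Pretrained 300M Prithvi EO v2 backbone; bands follow the pretraining band order.
    backbone = timm.create_model(
        "prithvi_eo_v2_300",
        pretrained=True,
        bands=[HLSBands.BLUE, HLSBands.GREEN, HLSBands.RED,
               HLSBands.NIR_NARROW, HLSBands.SWIR_1, HLSBands.SWIR_2],
    )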