From 63d4733fb3e110b4f389fe42da436eb1ed824be7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Thu, 18 Jul 2024 17:43:40 -0300 Subject: [PATCH 01/22] Bands could be define by intervals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 083a2241..ff0d6e6a 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -7,7 +7,7 @@ from abc import ABC from functools import partial from pathlib import Path -from typing import Any +from typing import Any, List, Union import albumentations as A import matplotlib as mpl @@ -122,18 +122,26 @@ def __init__( self.dataset_bands = dataset_bands self.output_bands = output_bands - if self.output_bands and not self.dataset_bands: - msg = "If output bands provided, dataset_bands must also be provided" - return Exception(msg) # noqa: PLE0101 - - if self.output_bands: - if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): - msg = "Output bands must be a subset of dataset bands" - raise Exception(msg) - self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + + bands_by_interval = (self._bands_defined_by_interval(bands_list=dataset_bands) and + self._bands_defined_by_interval(bands_list=output_bands)) + + if not bands_by_interval: + if self.output_bands and not self.dataset_bands: + msg = "If output bands provided, dataset_bands must also be provided" + return Exception(msg) # noqa: PLE0101 + + if self.output_bands: + if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): + msg = "Output bands must be a subset of dataset bands" + raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + else: + self.filter_indices = None else: - self.filter_indices = None - # If no transform is given, apply only to transform to torch tensor + pass + + # If no transform is given, apply only to transform to torch tensor self.transform = transform if transform else lambda **batch: to_tensor(batch) # self.transform = transform if transform else ToTensorV2() @@ -166,6 +174,13 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr data = data.fillna(nan_replace) return data + def _bands_defined_by_interval(self, bands_list: List[int] | List[List[int]] = None) -> bool: + if all([type(band)==int for band in bands_list]): + return False + elif all([isinstance(band, list) for band in bands_list]): + return True + else: + raise Exception(f"Excpected List[int] or List[List[int]], but received {type(bands_list)}.") class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" From 87310198209715139330a253ea826f7a76a69416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 10:05:20 -0300 Subject: [PATCH 02/22] Constructing the bands using the definition by interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index ff0d6e6a..f97faf74 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -120,28 +120,30 @@ def __init__( ) self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - self.dataset_bands = dataset_bands - self.output_bands = output_bands - bands_by_interval = (self._bands_defined_by_interval(bands_list=dataset_bands) and self._bands_defined_by_interval(bands_list=output_bands)) - if not bands_by_interval: - if self.output_bands and not self.dataset_bands: - msg = "If output bands provided, dataset_bands must also be provided" - return Exception(msg) # noqa: PLE0101 - - if self.output_bands: - if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): - msg = "Output bands must be a subset of dataset bands" - raise Exception(msg) - self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] - else: - self.filter_indices = None + # If the bands are defined by sub-intervals or not. + if bands_by_interval: + self.dataset_bands = self._generate_bands_intervals(dataset_bands) + self.output_bands = self._generate_bands_intervals(output_bands) + else: + self.dataset_bands = dataset_bands + self.output_bands = output_bands + + if self.output_bands and not self.dataset_bands: + msg = "If output bands provided, dataset_bands must also be provided" + return Exception(msg) # noqa: PLE0101 + + if self.output_bands: + if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): + msg = "Output bands must be a subset of dataset bands" + raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] else: - pass + self.filter_indices = None - # If no transform is given, apply only to transform to torch tensor + # If no transform is given, apply only to transform to torch tensor self.transform = transform if transform else lambda **batch: to_tensor(batch) # self.transform = transform if transform else ToTensorV2() @@ -174,6 +176,14 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr data = data.fillna(nan_replace) return data + def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): + bands = list() + for b_interval in bands_intervals: + b_interval[-1] += 1 + bands_sublist = np.arange(*b_interval).astype(int) + bands.append(bands_sublist) + return sorted(sum(bands, [])) + def _bands_defined_by_interval(self, bands_list: List[int] | List[List[int]] = None) -> bool: if all([type(band)==int for band in bands_list]): return False From 1dd66502de99fc00593fd985d3b34b7623ccddb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 10:29:09 -0300 Subject: [PATCH 03/22] Extending the supported formats for bands to include list[int] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index f97faf74..fbb9a646 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -185,7 +185,7 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): return sorted(sum(bands, [])) def _bands_defined_by_interval(self, bands_list: List[int] | List[List[int]] = None) -> bool: - if all([type(band)==int for band in bands_list]): + if all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]): return False elif all([isinstance(band, list) for band in bands_list]): return True @@ -373,8 +373,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | list[int]] | None = None, + output_bands: list[HLSBands | int | list[int]] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, From a0ce8aa8dc1b98c843027788949411ad1e6b1aaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 11:12:33 -0300 Subject: [PATCH 04/22] Extending the supported formats for bands to include list[int] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 12 +++++------ .../datasets/generic_pixel_wise_dataset.py | 20 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 16c4a31c..050a1583 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -91,9 +91,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int] | None = None, - predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | list[int]] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, + output_bands: list[HLSBands | int | list[int]] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -330,9 +330,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int] | None = None, - predict_dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | list[int]] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, + output_bands: list[HLSBands | int | list[int]] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index fbb9a646..90bf0e7d 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -43,8 +43,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | list[int]] | None = None, + output_bands: list[HLSBands | int | list[int]] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, @@ -179,16 +179,18 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): bands = list() for b_interval in bands_intervals: - b_interval[-1] += 1 - bands_sublist = np.arange(*b_interval).astype(int) + bands_sublist = np.arange(b_interval[0], b_interval[1] + 1).astype(int).tolist() bands.append(bands_sublist) return sorted(sum(bands, [])) - def _bands_defined_by_interval(self, bands_list: List[int] | List[List[int]] = None) -> bool: + def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool: if all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]): return False - elif all([isinstance(band, list) for band in bands_list]): - return True + elif all([isinstance(subinterval, list) for subinterval in bands_list]): + if all([type(band)==int for band in sum(bands_list, [])]): + return True + else: + raise Exception(f"Whe using subintervals, the limits must be int.") else: raise Exception(f"Excpected List[int] or List[List[int]], but received {type(bands_list)}.") @@ -206,8 +208,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, - dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | list[int]] | None = None, + output_bands: list[HLSBands | int | list[int]] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, From 295128fe3f30c7bf8dcecb40eca5fa0c8e2a0e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 11:13:17 -0300 Subject: [PATCH 05/22] Testing the definition by interval using a dedicated yaml file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- ...finetune_prithvi_swin_B_band_interval.yaml | 136 ++++++++++++++++++ tests/test_finetune.py | 13 ++ 2 files changed, 149 insertions(+) create mode 100644 tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml diff --git a/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml new file mode 100644 index 00000000..8697cd63 --- /dev/null +++ b/tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml @@ -0,0 +1,136 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - [0, 11] + output_bands: + - [1, 3] + - [4, 6] + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + diff --git a/tests/test_finetune.py b/tests/test_finetune.py index bb48b94e..3b18c812 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -23,6 +23,19 @@ def test_finetune_multiple_backbones(model_name): command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"] _ = build_lightning_cli(command_list) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) +def test_finetune_bands_intervals(model_name): + + model_instance = timm.create_model(model_name) + + state_dict = model_instance.state_dict() + + torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + + # Running the terratorch CLI + command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"] + _ = build_lightning_cli(command_list) + """ @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) def test_finetune_multiple_backbones(model_name): From 64dcf5d387e54f4710edf97cc625848ab7305824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Fri, 19 Jul 2024 11:34:05 -0300 Subject: [PATCH 06/22] Special case for bands_list=:None MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 90bf0e7d..e68a0a3d 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -184,7 +184,9 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): return sorted(sum(bands, [])) def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool: - if all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]): + if not bands_list: + return False + elif all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]): return False elif all([isinstance(subinterval, list) for subinterval in bands_list]): if all([type(band)==int for band in sum(bands_list, [])]): From e48965f712f5f48d5b5d31483dd82988b4c9f4fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 11:09:40 -0300 Subject: [PATCH 07/22] Basic support to use simple strings to name the bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index e68a0a3d..96de29f2 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -130,11 +130,18 @@ def __init__( else: self.dataset_bands = dataset_bands self.output_bands = output_bands + + bands_type = self._bands_as_int_or_str(dataset_bands, output_bands) + + if bands_type == str: + raise UserWarning("When the bands are defined as str, guarantee your input files"+ + "are organized by band and all have its specific name.") if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 + # There is a special condition if the bands are defined as simple strings. if self.output_bands: if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): msg = "Output bands must be a subset of dataset bands" @@ -183,10 +190,25 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): bands.append(bands_sublist) return sorted(sum(bands, [])) + def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: + + band_type = [None, None] + for b, bands_list in enumerate([dataset_bands, output_bands]): + if all([type(band)==int for band in bands_list]): + band_type[b] = int + elif all([type(band)==str for band in bands_list]): + band_type[b] = str + else: + pass + if band_type.cound(band_type[0]) == len(band_type) + return band_type[0] + else: + raise Exception("The bands must be or all str or all int.") + def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool: if not bands_list: return False - elif all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]): + elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]): return False elif all([isinstance(subinterval, list) for subinterval in bands_list]): if all([type(band)==int for band in sum(bands_list, [])]): @@ -194,7 +216,7 @@ def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = N else: raise Exception(f"Whe using subintervals, the limits must be int.") else: - raise Exception(f"Excpected List[int] or List[List[int]], but received {type(bands_list)}.") + raise Exception(f"Excpected List[int] or List[str] or List[List[int]], but received {type(bands_list)}.") class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" From 5dba48101bfd6a5a7e80d647b1e2056174440b27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 13:25:17 -0300 Subject: [PATCH 08/22] Strings are allowed to define bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 12 ++++++------ .../datasets/generic_pixel_wise_dataset.py | 17 ++++++++++------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 050a1583..a79f82fb 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -91,9 +91,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -330,9 +330,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 96de29f2..ce9991aa 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -88,6 +88,7 @@ def __init__( expected 0. Defaults to False. """ super().__init__() + self.split_file = split label_data_root = label_data_root if label_data_root is not None else data_root @@ -136,7 +137,7 @@ def __init__( if bands_type == str: raise UserWarning("When the bands are defined as str, guarantee your input files"+ "are organized by band and all have its specific name.") - + if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 @@ -146,7 +147,9 @@ def __init__( if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + else: self.filter_indices = None @@ -176,7 +179,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.transform: output = self.transform(**output) return output - + def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: data = rioxarray.open_rasterio(path, masked=True) if nan_replace is not None: @@ -200,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: band_type[b] = str else: pass - if band_type.cound(band_type[0]) == len(band_type) + if band_type.cound(band_type[0]) == len(band_type): return band_type[0] else: raise Exception("The bands must be or all str or all int.") @@ -232,8 +235,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -399,8 +402,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, From 174f2f16609c967a939c0aae30bbd9b2abf16785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 14:25:18 -0300 Subject: [PATCH 09/22] Testing to use strings to define a model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 4 ++-- tests/test_finetune.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index ce9991aa..d486348b 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -135,7 +135,7 @@ def __init__( bands_type = self._bands_as_int_or_str(dataset_bands, output_bands) if bands_type == str: - raise UserWarning("When the bands are defined as str, guarantee your input files"+ + UserWarning("When the bands are defined as str, guarantee your input files"+ "are organized by band and all have its specific name.") if self.output_bands and not self.dataset_bands: @@ -203,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: band_type[b] = str else: pass - if band_type.cound(band_type[0]) == len(band_type): + if band_type.count(band_type[0]) == len(band_type): return band_type[0] else: raise Exception("The bands must be or all str or all int.") diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 3b18c812..535f96ed 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -36,6 +36,20 @@ def test_finetune_bands_intervals(model_name): command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"] _ = build_lightning_cli(command_list) +@pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) +def test_finetune_bands_intervals(model_name): + + model_instance = timm.create_model(model_name) + + state_dict = model_instance.state_dict() + + torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) + + # Running the terratorch CLI + command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] + _ = build_lightning_cli(command_list) + + """ @pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) def test_finetune_multiple_backbones(model_name): From 989bf8003d4d722273570ade2ccdd4734bb7bcef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 15:12:16 -0300 Subject: [PATCH 10/22] Exception for None inputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index d486348b..42564ccd 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -174,6 +174,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: "mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0], "filename": self.image_files[index], } + if self.reduce_zero_label: output["mask"] -= 1 if self.transform: @@ -196,17 +197,20 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: band_type = [None, None] - for b, bands_list in enumerate([dataset_bands, output_bands]): - if all([type(band)==int for band in bands_list]): - band_type[b] = int - elif all([type(band)==str for band in bands_list]): - band_type[b] = str - else: - pass - if band_type.count(band_type[0]) == len(band_type): - return band_type[0] + if not dataset_bands and not output_bands: + return None else: - raise Exception("The bands must be or all str or all int.") + for b, bands_list in enumerate([dataset_bands, output_bands]): + if all([type(band)==int for band in bands_list]): + band_type[b] = int + elif all([type(band)==str for band in bands_list]): + band_type[b] = str + else: + pass + if band_type.count(band_type[0]) == len(band_type): + return band_type[0] + else: + raise Exception("The bands must be or all str or all int.") def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool: if not bands_list: From 831c662f44f6e4cb9cf191af78168581b37109e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 15:21:22 -0300 Subject: [PATCH 11/22] Support for str MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 42564ccd..2262ae1c 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -43,8 +43,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, From de533dd86462cfa8320d325bd58e6e646fc72359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 15:25:02 -0300 Subject: [PATCH 12/22] YAML file for testing string as bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- ...ctured-finetune_prithvi_swin_B_string.yaml | 149 ++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 tests/manufactured-finetune_prithvi_swin_B_string.yaml diff --git a/tests/manufactured-finetune_prithvi_swin_B_string.yaml b/tests/manufactured-finetune_prithvi_swin_B_string.yaml new file mode 100644 index 00000000..a7aa84c2 --- /dev/null +++ b/tests/manufactured-finetune_prithvi_swin_B_string.yaml @@ -0,0 +1,149 @@ +# lightning.pytorch==2.1.1 +seed_everything: 42 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + # precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: tests/ + name: all_ecos_random + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 100 + max_epochs: 5 + check_val_every_n_epoch: 1 + log_every_n_steps: 20 + enable_checkpointing: true + default_root_dir: tests/ +data: + class_path: GenericNonGeoPixelwiseRegressionDataModule + init_args: + batch_size: 2 + num_workers: 4 + train_transform: + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: albumentations.Rotate + init_args: + limit: 30 + border_mode: 0 # cv2.BORDER_CONSTANT + value: 0 + # mask_value: 1 + p: 0.5 + - class_path: ToTensorV2 + dataset_bands: + - "band_1" + - "band_2" + - "band_3" + - "band_4" + - "band_5" + - "band_6" + - "band_7" + - "band_8" + - "band_9" + - "band_10" + output_bands: + - "band_2" + - "band_3" + - "band_4" + - "band_5" + - "band_6" + - "band_7" + rgb_indices: + - 2 + - 1 + - 0 + train_data_root: tests/ + train_label_data_root: tests/ + val_data_root: tests/ + val_label_data_root: tests/ + test_data_root: tests/ + test_label_data_root: tests/ + img_grep: "regression*input*.tif" + label_grep: "regression*label*.tif" + means: + - 547.36707 + - 898.5121 + - 1020.9082 + - 2665.5352 + - 2340.584 + - 1610.1407 + stds: + - 411.4701 + - 558.54065 + - 815.94025 + - 812.4403 + - 1113.7145 + - 1067.641 + no_label_replace: -1 + no_data_replace: 0 + +model: + class_path: terratorch.tasks.PixelwiseRegressionTask + init_args: + model_args: + decoder: UperNetDecoder + pretrained: true + backbone: prithvi_swin_B + backbone_pretrained_cfg_overlay: + file: tests/prithvi_swin_B.pt + backbone_drop_path_rate: 0.3 + # backbone_window_size: 8 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + head_dropout: 0.5708022831486758 + head_final_act: torch.nn.ReLU + head_learned_upscale_layers: 2 + loss: rmse + #aux_heads: + # - name: aux_head + # decoder: IdentityDecoder + # decoder_args: + # decoder_out_index: 2 + # head_dropout: 0,5 + # head_channel_list: + # - 64 + # head_final_act: torch.nn.ReLU + #aux_loss: + # aux_head: 0.4 + ignore_index: -1 + freeze_backbone: true + freeze_decoder: false + model_factory: PrithviModelFactory + + # uncomment this block for tiled inference + # tiled_inference_parameters: + # h_crop: 224 + # h_stride: 192 + # w_crop: 224 + # w_stride: 192 + # average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 0.00013524680528283027 + weight_decay: 0.047782217873995426 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss + From 7a94a829b12b66e3ea9ccba7b917225e2dad8284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 15:46:18 -0300 Subject: [PATCH 13/22] This test is no longer required MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/test_finetune.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 535f96ed..cd4f5356 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -37,7 +37,7 @@ def test_finetune_bands_intervals(model_name): _ = build_lightning_cli(command_list) @pytest.mark.parametrize("model_name", ["prithvi_swin_B"]) -def test_finetune_bands_intervals(model_name): +def test_finetune_bands_str(model_name): model_instance = timm.create_model(model_name) @@ -49,22 +49,3 @@ def test_finetune_bands_intervals(model_name): command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"] _ = build_lightning_cli(command_list) - -""" -@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"]) -def test_finetune_multiple_backbones(model_name): - - model_instance = timm.create_model(model_name) - pretrained_bands = [0, 1, 2, 3, 4, 5] - model_bands = [0, 1, 2, 3, 4, 5] - - state_dict = model_instance.state_dict() - - torch.save(state_dict, os.path.join("tests/", model_name + ".pt")) - - # Running the terratorch CLI - command_str = f"python terratorch/__main__.py fit -c tests/manufactured-finetune_{model_name}.yaml" - command_out = subprocess.run(command_str, shell=True) - - assert not command_out.returncode - """ From 0a84c42e0975fcc0a57740f4dc6b3340d3e956a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 10:12:19 -0300 Subject: [PATCH 14/22] Band intervals should be tuples with two entries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 12 +++++----- .../datasets/generic_pixel_wise_dataset.py | 22 ++++++++++--------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index a79f82fb..8dba1c20 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -91,9 +91,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int] | str] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - output_bands: list[HLSBands | int | list[int] | str] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -330,9 +330,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - output_bands: list[HLSBands | int | list[int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 2262ae1c..4525ffdb 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -43,8 +43,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - output_bands: list[HLSBands | int | list[int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, @@ -212,18 +212,20 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: else: raise Exception("The bands must be or all str or all int.") - def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool: + def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool: if not bands_list: return False elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]): return False - elif all([isinstance(subinterval, list) for subinterval in bands_list]): - if all([type(band)==int for band in sum(bands_list, [])]): + elif all([isinstance(subinterval, tuple) for subinterval in bands_list]): + bands_list_ = [list(subinterval) for subinterval in bands_list] + if all([type(band)==int for band in sum(bands_list_, [])]): return True else: raise Exception(f"Whe using subintervals, the limits must be int.") else: - raise Exception(f"Excpected List[int] or List[str] or List[List[int]], but received {type(bands_list)}.") + print(bands_list) + raise Exception(f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}.") class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" @@ -239,8 +241,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, - dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - output_bands: list[HLSBands | int | list[int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -406,8 +408,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, - output_bands: list[HLSBands | int | list[int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, From 288825243c296b1a93f06af6cb6790f238000b46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 10:27:07 -0300 Subject: [PATCH 15/22] More compact way to check if the bands are defined by interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 4525ffdb..42ebbdb0 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -121,11 +121,10 @@ def __init__( ) self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - bands_by_interval = (self._bands_defined_by_interval(bands_list=dataset_bands) and - self._bands_defined_by_interval(bands_list=output_bands)) + is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands) # If the bands are defined by sub-intervals or not. - if bands_by_interval: + if is_bands_by_interval: self.dataset_bands = self._generate_bands_intervals(dataset_bands) self.output_bands = self._generate_bands_intervals(output_bands) else: @@ -212,6 +211,19 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: else: raise Exception("The bands must be or all str or all int.") + def _check_if_its_defined_by_interval(self, dataset_bands: list[int] | list[tuple[int]] = None, + output_bands: list[int] | list[tuple[int]] = None) -> bool: + + is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands) + is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands) + + if is_dataset_bands_defined and is_output_bands_defined: + return True + elif not is_dataset_bands_defined and not is_output_bands_defined: + return False + else: + raise Exception(f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}") + def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool: if not bands_list: return False @@ -224,7 +236,6 @@ def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = else: raise Exception(f"Whe using subintervals, the limits must be int.") else: - print(bands_list) raise Exception(f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}.") class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): From 303bfe89a22a20cbe65de90fbf4d701aff7df25f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 10:28:26 -0300 Subject: [PATCH 16/22] This warning is not necessary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 42ebbdb0..2241f38a 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -130,12 +130,6 @@ def __init__( else: self.dataset_bands = dataset_bands self.output_bands = output_bands - - bands_type = self._bands_as_int_or_str(dataset_bands, output_bands) - - if bands_type == str: - UserWarning("When the bands are defined as str, guarantee your input files"+ - "are organized by band and all have its specific name.") if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" From 25b10c60caae3838d92b41b8b30c2d3c067bb90d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 10:32:21 -0300 Subject: [PATCH 17/22] Reformatting using black MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../datasets/generic_pixel_wise_dataset.py | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 2241f38a..ddedb129 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -43,8 +43,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, @@ -121,7 +121,7 @@ def __init__( ) self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands) + is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands) # If the bands are defined by sub-intervals or not. if is_bands_by_interval: @@ -130,7 +130,7 @@ def __init__( else: self.dataset_bands = dataset_bands self.output_bands = output_bands - + if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 @@ -154,7 +154,7 @@ def __len__(self) -> int: return len(self.image_files) def __getitem__(self, index: int) -> dict[str, Any]: - image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy() + image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy() # to channels last if self.expand_temporal_dimension: image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands)) @@ -164,7 +164,9 @@ def __getitem__(self, index: int) -> dict[str, Any]: image = image[..., self.filter_indices] output = { "image": image.astype(np.float32) * self.constant_scale, - "mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0], + "mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[ + 0 + ], "filename": self.image_files[index], } @@ -173,14 +175,14 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.transform: output = self.transform(**output) return output - + def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: data = rioxarray.open_rasterio(path, masked=True) if nan_replace is not None: data = data.fillna(nan_replace) return data - def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None): + def _generate_bands_intervals(self, bands_intervals: List[List[int]] = None): bands = list() for b_interval in bands_intervals: bands_sublist = np.arange(b_interval[0], b_interval[1] + 1).astype(int).tolist() @@ -194,19 +196,20 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: return None else: for b, bands_list in enumerate([dataset_bands, output_bands]): - if all([type(band)==int for band in bands_list]): + if all([type(band) == int for band in bands_list]): band_type[b] = int - elif all([type(band)==str for band in bands_list]): + elif all([type(band) == str for band in bands_list]): band_type[b] = str else: - pass + pass if band_type.count(band_type[0]) == len(band_type): return band_type[0] else: raise Exception("The bands must be or all str or all int.") - def _check_if_its_defined_by_interval(self, dataset_bands: list[int] | list[tuple[int]] = None, - output_bands: list[int] | list[tuple[int]] = None) -> bool: + def _check_if_its_defined_by_interval( + self, dataset_bands: list[int] | list[tuple[int]] = None, output_bands: list[int] | list[tuple[int]] = None + ) -> bool: is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands) is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands) @@ -216,21 +219,26 @@ def _check_if_its_defined_by_interval(self, dataset_bands: list[int] | list[tupl elif not is_dataset_bands_defined and not is_output_bands_defined: return False else: - raise Exception(f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}") + raise Exception( + f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}" + ) def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool: if not bands_list: return False - elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]): + elif all([type(band) == int or type(band) == str or isinstance(band, HLSBands) for band in bands_list]): return False elif all([isinstance(subinterval, tuple) for subinterval in bands_list]): bands_list_ = [list(subinterval) for subinterval in bands_list] - if all([type(band)==int for band in sum(bands_list_, [])]): + if all([type(band) == int for band in sum(bands_list_, [])]): return True else: raise Exception(f"Whe using subintervals, the limits must be int.") else: - raise Exception(f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}.") + raise Exception( + f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}." + ) + class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" @@ -246,8 +254,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, - dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -413,8 +421,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, - output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None, From 10551a0d59295101575540925bc6e49bc276e3fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 10:41:19 -0300 Subject: [PATCH 18/22] Minor improvements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index ddedb129..6f8c2957 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -183,11 +183,11 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr return data def _generate_bands_intervals(self, bands_intervals: List[List[int]] = None): - bands = list() + bands = [] for b_interval in bands_intervals: - bands_sublist = np.arange(b_interval[0], b_interval[1] + 1).astype(int).tolist() + bands_sublist = list(range(b_interval[0], b_interval[1] + 1)) bands.append(bands_sublist) - return sorted(sum(bands, [])) + return reduce(operator.iadd, bands, []) def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: From aedca3222ba24f45133bad9c4df7c27c1ccd9050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Tue, 23 Jul 2024 11:43:42 -0300 Subject: [PATCH 19/22] Missing imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- terratorch/datasets/generic_pixel_wise_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 6f8c2957..78533a1e 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -8,7 +8,8 @@ from functools import partial from pathlib import Path from typing import Any, List, Union - +from functools import reduce +import operator import albumentations as A import matplotlib as mpl import numpy as np From 5708e0a3f433b999c2c4fce520a52d4593cfa942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Wed, 24 Jul 2024 15:36:05 -0300 Subject: [PATCH 20/22] More tests to check if the bands ar properly returned MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- tests/test_generic_dataset.py | 103 +++++++++++++++++++++++++++++++++- 1 file changed, 101 insertions(+), 2 deletions(-) diff --git a/tests/test_generic_dataset.py b/tests/test_generic_dataset.py index 7ec70933..5214cce3 100644 --- a/tests/test_generic_dataset.py +++ b/tests/test_generic_dataset.py @@ -14,7 +14,6 @@ SEGMENTATION_LABEL_PATH = "tests/segmentation_test_label.tif" NUM_CLASSES_SEGMENTATION = 2 - @pytest.fixture(scope="session") def split_file_path(tmp_path_factory): split_file_path = tmp_path_factory.mktemp("split") / "split.txt" @@ -59,7 +58,6 @@ def test_data_type_regression_float_float(self, regression_dataset): assert torch.is_floating_point(regression_dataset[0]["image"]) assert torch.is_floating_point(regression_dataset[0]["mask"]) - class TestGenericSegmentationDataset: @pytest.fixture(scope="class") def data_root_segmentation(self, tmp_path_factory: TempPathFactory): @@ -94,3 +92,104 @@ def test_file_discovery_generic_segmentation_dataset(self, segmentation_dataset) def test_data_type_regression_float_long(self, segmentation_dataset): assert torch.is_floating_point(segmentation_dataset[0]["image"]) assert not torch.is_floating_point(segmentation_dataset[0]["mask"]) + +# Testing bands +# HLS_bands +HLS_dataset_bands = [ + "COASTAL_AEROSOL", + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2", + "CIRRUS", + "THEMRAL_INFRARED_1", + "THEMRAL_INFRARED_2", +] + +HLS_output_bands = [ + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2", +] + +# Integer Intervals bands +int_dataset_bands = (0,10) +int_output_bands = (1,6) +# Simple string bands +str_dataset_bands = [f"band_{j}" for j in range(10)] +str_output_bands = [f"band_{j}" for j in range(1,6)] + + +class TestGenericDatasetWithBands: + @pytest.fixture(scope="class") + def data_root_regression(self, tmp_path_factory: TempPathFactory): + data_dir = tmp_path_factory.mktemp("data") + image_dir_path = data_dir / "input_data" + label_dir_path = data_dir / "label_data" + os.mkdir(image_dir_path) + os.mkdir(label_dir_path) + for i in range(10): + os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}_img.tif") + os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}_label.tif") + + # add a few with no suffix + for i in range(10, 15): + os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}.tif") + os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}.tif") + return data_dir + + @pytest.fixture(scope="class") + def regression_dataset_with_HLS_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=HLS_dataset_bands, + output_bands=HLS_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ) + + @pytest.fixture(scope="class") + def regression_dataset_with_interval_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=[int_dataset_bands], + output_bands=[int_output_bands], + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ) + + @pytest.fixture(scope="class") + def regression_dataset_with_str_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=str_dataset_bands, + output_bands=str_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ) + + def test_usage_of_HLS_bands(self, regression_dataset_with_HLS_bands): + + dataset = regression_dataset_with_HLS_bands + assert dataset.output_bands == HLS_output_bands + + def test_usage_of_interval_bands(self, regression_dataset_with_interval_bands): + + dataset = regression_dataset_with_interval_bands + int_output_bands_ = list(int_output_bands) + int_output_bands_[1] += 1 + assert dataset.output_bands == list(range(*int_output_bands_)) + + def test_usage_of_str_bands(self, regression_dataset_with_str_bands): + + dataset = regression_dataset_with_str_bands + assert dataset.output_bands == str_output_bands + From 3faf7d3a2a9e90c057d6f12a88babe1075fb91fc Mon Sep 17 00:00:00 2001 From: Carlos Gomes Date: Fri, 26 Jul 2024 12:00:37 +0200 Subject: [PATCH 21/22] accept mixed band specifications Signed-off-by: Carlos Gomes --- .../datasets/generic_pixel_wise_dataset.py | 98 ++------ .../datasets/generic_scalar_label_dataset.py | 8 +- tests/test_generic_dataset.py | 235 ++++++++++++------ 3 files changed, 184 insertions(+), 157 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 78533a1e..ed0b65b4 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -1,24 +1,21 @@ # Copyright contributors to the Terratorch project -"""Module containing generic dataset classes -""" +"""Module containing generic dataset classes""" + import glob +import operator import os from abc import ABC -from functools import partial -from pathlib import Path -from typing import Any, List, Union from functools import reduce -import operator +from pathlib import Path +from typing import Any + import albumentations as A import matplotlib as mpl import numpy as np import rioxarray -import torch import xarray as xr -from albumentations.pytorch import ToTensorV2 from einops import rearrange -from matplotlib import cm from matplotlib import pyplot as plt from matplotlib.figure import Figure from matplotlib.patches import Rectangle @@ -122,15 +119,8 @@ def __init__( ) self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands) - - # If the bands are defined by sub-intervals or not. - if is_bands_by_interval: - self.dataset_bands = self._generate_bands_intervals(dataset_bands) - self.output_bands = self._generate_bands_intervals(output_bands) - else: - self.dataset_bands = dataset_bands - self.output_bands = output_bands + self.dataset_bands = self._generate_bands_intervals(dataset_bands) + self.output_bands = self._generate_bands_intervals(output_bands) if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" @@ -183,63 +173,25 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr data = data.fillna(nan_replace) return data - def _generate_bands_intervals(self, bands_intervals: List[List[int]] = None): - bands = [] - for b_interval in bands_intervals: - bands_sublist = list(range(b_interval[0], b_interval[1] + 1)) - bands.append(bands_sublist) - return reduce(operator.iadd, bands, []) - - def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: - - band_type = [None, None] - if not dataset_bands and not output_bands: + def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None): + if bands_intervals is None: return None - else: - for b, bands_list in enumerate([dataset_bands, output_bands]): - if all([type(band) == int for band in bands_list]): - band_type[b] = int - elif all([type(band) == str for band in bands_list]): - band_type[b] = str - else: - pass - if band_type.count(band_type[0]) == len(band_type): - return band_type[0] - else: - raise Exception("The bands must be or all str or all int.") - - def _check_if_its_defined_by_interval( - self, dataset_bands: list[int] | list[tuple[int]] = None, output_bands: list[int] | list[tuple[int]] = None - ) -> bool: - - is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands) - is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands) - - if is_dataset_bands_defined and is_output_bands_defined: - return True - elif not is_dataset_bands_defined and not is_output_bands_defined: - return False - else: - raise Exception( - f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}" - ) - - def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool: - if not bands_list: - return False - elif all([type(band) == int or type(band) == str or isinstance(band, HLSBands) for band in bands_list]): - return False - elif all([isinstance(subinterval, tuple) for subinterval in bands_list]): - bands_list_ = [list(subinterval) for subinterval in bands_list] - if all([type(band) == int for band in sum(bands_list_, [])]): - return True + bands = [] + for element in bands_intervals: + # if its an interval + if isinstance(element, tuple): + if len(element) != 2: # noqa: PLR2004 + msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive" + raise Exception(msg) + expanded_element = list(range(element[0], element[1] + 1)) + bands.extend(expanded_element) else: - raise Exception(f"Whe using subintervals, the limits must be int.") - else: - raise Exception( - f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}." - ) - + bands.append(element) + # check the expansion didnt result in duplicate elements + if len(set(bands)) != len(bands): + msg = "Duplicate indices detected. Indices must be unique." + raise Exception(msg) + return bands class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" diff --git a/terratorch/datasets/generic_scalar_label_dataset.py b/terratorch/datasets/generic_scalar_label_dataset.py index 85b16a75..65db7cdb 100644 --- a/terratorch/datasets/generic_scalar_label_dataset.py +++ b/terratorch/datasets/generic_scalar_label_dataset.py @@ -110,17 +110,21 @@ def is_valid_file(x): self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices - self.dataset_bands = dataset_bands - self.output_bands = output_bands + self.dataset_bands = self._generate_bands_intervals(dataset_bands) + self.output_bands = self._generate_bands_intervals(output_bands) + if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 + # There is a special condition if the bands are defined as simple strings. if self.output_bands: if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + else: self.filter_indices = None # If no transform is given, apply only to transform to torch tensor diff --git a/tests/test_generic_dataset.py b/tests/test_generic_dataset.py index 5214cce3..9ac8bd2a 100644 --- a/tests/test_generic_dataset.py +++ b/tests/test_generic_dataset.py @@ -6,7 +6,7 @@ import torch from _pytest.tmpdir import TempPathFactory -from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset +from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands REGRESSION_IMAGE_PATH = "tests/regression_test_input.tif" REGRESSION_LABEL_PATH = "tests/regression_test_label.tif" @@ -14,6 +14,57 @@ SEGMENTATION_LABEL_PATH = "tests/segmentation_test_label.tif" NUM_CLASSES_SEGMENTATION = 2 +# Testing bands +# HLS_bands +HLS_dataset_bands = [ + "COASTAL_AEROSOL", + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2", + "CIRRUS", + "THEMRAL_INFRARED_1", + "THEMRAL_INFRARED_2", +] + +HLS_output_bands = [ + "BLUE", + "GREEN", + "RED", + "NIR_NARROW", + "SWIR_1", + "SWIR_2", +] + +HLS_expected_filter_bands = list(range(1, 7)) +# Integer Intervals bands +int_dataset_bands = [(0, 20)] +int_output_bands = [(1, 6), (10, 12)] +# Simple string bands +str_dataset_bands = [f"band_{j}" for j in range(20)] +str_output_bands = [f"band_{j}" for j in range(1, 7)] + [f"band_{j}" for j in range(10, 13)] + +expected_filter_indices = list(range(1, 7)) + list(range(10, 13)) + + +# Mixed case +mixed_dataset_bands = [ + (0, 10), + HLSBands.RED, + HLSBands.BLUE, + HLSBands.GREEN, + "extra_band_1", + "extra_band_2", + 200, + 201, + 202, +] +mixed_output_bands = [1, 2, HLSBands.BLUE, "extra_band_1", 201, 202] +expected_mixed_filter_indices = [1, 2, 12, 14, 17, 18] + + @pytest.fixture(scope="session") def split_file_path(tmp_path_factory): split_file_path = tmp_path_factory.mktemp("split") / "split.txt" @@ -58,6 +109,65 @@ def test_data_type_regression_float_float(self, regression_dataset): assert torch.is_floating_point(regression_dataset[0]["image"]) assert torch.is_floating_point(regression_dataset[0]["mask"]) + @pytest.fixture(scope="class") + def regression_dataset_with_HLS_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=HLS_dataset_bands, + output_bands=HLS_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ), HLS_expected_filter_bands + + @pytest.fixture(scope="class") + def regression_dataset_with_interval_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=int_dataset_bands, + output_bands=int_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ), expected_filter_indices + + @pytest.fixture(scope="class") + def regression_dataset_with_str_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=str_dataset_bands, + output_bands=str_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ), expected_filter_indices + + @pytest.fixture(scope="class") + def regression_dataset_with_mixed_bands(self, data_root_regression, split_file_path): + return GenericNonGeoPixelwiseRegressionDataset( + data_root_regression, + dataset_bands=mixed_dataset_bands, + output_bands=mixed_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ), expected_mixed_filter_indices + + @pytest.mark.parametrize( + "dataset", + [ + "regression_dataset_with_HLS_bands", + "regression_dataset_with_str_bands", + "regression_dataset_with_interval_bands", + "regression_dataset_with_str_bands", + "regression_dataset_with_mixed_bands", + ], + ) + def test_correct_filter(self, dataset, request): + fixture, expected = request.getfixturevalue(dataset) + assert fixture.filter_indices == expected + + class TestGenericSegmentationDataset: @pytest.fixture(scope="class") def data_root_segmentation(self, tmp_path_factory: TempPathFactory): @@ -93,103 +203,64 @@ def test_data_type_regression_float_long(self, segmentation_dataset): assert torch.is_floating_point(segmentation_dataset[0]["image"]) assert not torch.is_floating_point(segmentation_dataset[0]["mask"]) -# Testing bands -# HLS_bands -HLS_dataset_bands = [ - "COASTAL_AEROSOL", - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2", - "CIRRUS", - "THEMRAL_INFRARED_1", - "THEMRAL_INFRARED_2", -] - -HLS_output_bands = [ - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2", -] - -# Integer Intervals bands -int_dataset_bands = (0,10) -int_output_bands = (1,6) -# Simple string bands -str_dataset_bands = [f"band_{j}" for j in range(10)] -str_output_bands = [f"band_{j}" for j in range(1,6)] - - -class TestGenericDatasetWithBands: @pytest.fixture(scope="class") - def data_root_regression(self, tmp_path_factory: TempPathFactory): - data_dir = tmp_path_factory.mktemp("data") - image_dir_path = data_dir / "input_data" - label_dir_path = data_dir / "label_data" - os.mkdir(image_dir_path) - os.mkdir(label_dir_path) - for i in range(10): - os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}_img.tif") - os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}_label.tif") - - # add a few with no suffix - for i in range(10, 15): - os.symlink(REGRESSION_IMAGE_PATH, image_dir_path / f"{i}.tif") - os.symlink(REGRESSION_LABEL_PATH, label_dir_path / f"{i}.tif") - return data_dir - - @pytest.fixture(scope="class") - def regression_dataset_with_HLS_bands(self, data_root_regression, split_file_path): - return GenericNonGeoPixelwiseRegressionDataset( - data_root_regression, + def segmentation_dataset_with_HLS_bands(self, data_root_segmentation, split_file_path): + return GenericNonGeoSegmentationDataset( + data_root_segmentation, + NUM_CLASSES_SEGMENTATION, dataset_bands=HLS_dataset_bands, output_bands=HLS_output_bands, image_grep="input_data/*_img.tif", label_grep="label_data/*_label.tif", split=split_file_path, - ) + ), HLS_expected_filter_bands @pytest.fixture(scope="class") - def regression_dataset_with_interval_bands(self, data_root_regression, split_file_path): - return GenericNonGeoPixelwiseRegressionDataset( - data_root_regression, - dataset_bands=[int_dataset_bands], - output_bands=[int_output_bands], + def segmentation_dataset_with_interval_bands(self, data_root_segmentation, split_file_path): + return GenericNonGeoSegmentationDataset( + data_root_segmentation, + NUM_CLASSES_SEGMENTATION, + dataset_bands=int_dataset_bands, + output_bands=int_output_bands, image_grep="input_data/*_img.tif", label_grep="label_data/*_label.tif", split=split_file_path, - ) + ), expected_filter_indices @pytest.fixture(scope="class") - def regression_dataset_with_str_bands(self, data_root_regression, split_file_path): - return GenericNonGeoPixelwiseRegressionDataset( - data_root_regression, + def segmentation_dataset_with_str_bands(self, data_root_segmentation, split_file_path): + return GenericNonGeoSegmentationDataset( + data_root_segmentation, + NUM_CLASSES_SEGMENTATION, dataset_bands=str_dataset_bands, output_bands=str_output_bands, image_grep="input_data/*_img.tif", label_grep="label_data/*_label.tif", split=split_file_path, - ) - - def test_usage_of_HLS_bands(self, regression_dataset_with_HLS_bands): - - dataset = regression_dataset_with_HLS_bands - assert dataset.output_bands == HLS_output_bands - - def test_usage_of_interval_bands(self, regression_dataset_with_interval_bands): - - dataset = regression_dataset_with_interval_bands - int_output_bands_ = list(int_output_bands) - int_output_bands_[1] += 1 - assert dataset.output_bands == list(range(*int_output_bands_)) - - def test_usage_of_str_bands(self, regression_dataset_with_str_bands): - - dataset = regression_dataset_with_str_bands - assert dataset.output_bands == str_output_bands + ), expected_filter_indices + @pytest.fixture(scope="class") + def segmentation_dataset_with_mixed_bands(self, data_root_segmentation, split_file_path): + return GenericNonGeoSegmentationDataset( + data_root_segmentation, + NUM_CLASSES_SEGMENTATION, + dataset_bands=mixed_dataset_bands, + output_bands=mixed_output_bands, + image_grep="input_data/*_img.tif", + label_grep="label_data/*_label.tif", + split=split_file_path, + ), expected_mixed_filter_indices + + @pytest.mark.parametrize( + "dataset", + [ + "segmentation_dataset_with_HLS_bands", + "segmentation_dataset_with_str_bands", + "segmentation_dataset_with_interval_bands", + "segmentation_dataset_with_str_bands", + "segmentation_dataset_with_mixed_bands", + ], + ) + def test_correct_filter(self, dataset, request): + fixture, expected = request.getfixturevalue(dataset) + assert fixture.filter_indices == expected From 1a254e42d0eb58479a366a0327432c0990cc9ec2 Mon Sep 17 00:00:00 2001 From: Carlos Gomes Date: Fri, 26 Jul 2024 15:54:28 +0200 Subject: [PATCH 22/22] improve docstring comments Signed-off-by: Carlos Gomes --- terratorch/datasets/generic_pixel_wise_dataset.py | 7 +++---- terratorch/datasets/generic_scalar_label_dataset.py | 8 ++++---- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index ed0b65b4..acb07ec7 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -3,10 +3,8 @@ """Module containing generic dataset classes""" import glob -import operator import os from abc import ABC -from functools import reduce from pathlib import Path from typing import Any @@ -71,8 +69,8 @@ def __init__( that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None. + output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module, @@ -193,6 +191,7 @@ def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | raise Exception(msg) return bands + class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset): """GenericNonGeoSegmentationDataset""" diff --git a/terratorch/datasets/generic_scalar_label_dataset.py b/terratorch/datasets/generic_scalar_label_dataset.py index 65db7cdb..a4785569 100644 --- a/terratorch/datasets/generic_scalar_label_dataset.py +++ b/terratorch/datasets/generic_scalar_label_dataset.py @@ -42,8 +42,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int] | None = None, - output_bands: list[HLSBands | int] | None = None, + dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, + output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float = 0, @@ -64,8 +64,8 @@ def __init__( that must be present in file names to be included (as in mmsegmentation), or exact matches (e.g. eurosat). Defaults to True. rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2]. - dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. - output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset. + dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None. + output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands. constant_scale (float): Factor to multiply image values by. Defaults to 1. transform (Albumentations.Compose | None): Albumentations transform to be applied. Should end with ToTensorV2(). If used through the generic_data_module,