From 5dba48101bfd6a5a7e80d647b1e2056174440b27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Lucas=20de=20Sousa=20Almeida?= Date: Mon, 22 Jul 2024 13:25:17 -0300 Subject: [PATCH] Strings are allowed to define bands MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lucas de Sousa Almeida --- .../generic_pixel_wise_data_module.py | 12 ++++++------ .../datasets/generic_pixel_wise_dataset.py | 17 ++++++++++------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/terratorch/datamodules/generic_pixel_wise_data_module.py b/terratorch/datamodules/generic_pixel_wise_data_module.py index 050a1583..a79f82fb 100644 --- a/terratorch/datamodules/generic_pixel_wise_data_module.py +++ b/terratorch/datamodules/generic_pixel_wise_data_module.py @@ -91,9 +91,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, @@ -330,9 +330,9 @@ def __init__( test_split: Path | None = None, ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - predict_dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, constant_scale: float = 1, rgb_indices: list[int] | None = None, train_transform: A.Compose | None | list[A.BasicTransform] = None, diff --git a/terratorch/datasets/generic_pixel_wise_dataset.py b/terratorch/datasets/generic_pixel_wise_dataset.py index 96de29f2..ce9991aa 100644 --- a/terratorch/datasets/generic_pixel_wise_dataset.py +++ b/terratorch/datasets/generic_pixel_wise_dataset.py @@ -88,6 +88,7 @@ def __init__( expected 0. Defaults to False. """ super().__init__() + self.split_file = split label_data_root = label_data_root if label_data_root is not None else data_root @@ -136,7 +137,7 @@ def __init__( if bands_type == str: raise UserWarning("When the bands are defined as str, guarantee your input files"+ "are organized by band and all have its specific name.") - + if self.output_bands and not self.dataset_bands: msg = "If output bands provided, dataset_bands must also be provided" return Exception(msg) # noqa: PLE0101 @@ -146,7 +147,9 @@ def __init__( if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands): msg = "Output bands must be a subset of dataset bands" raise Exception(msg) + self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands] + else: self.filter_indices = None @@ -176,7 +179,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: if self.transform: output = self.transform(**output) return output - + def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray: data = rioxarray.open_rasterio(path, masked=True) if nan_replace is not None: @@ -200,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type: band_type[b] = str else: pass - if band_type.cound(band_type[0]) == len(band_type) + if band_type.cound(band_type[0]) == len(band_type): return band_type[0] else: raise Exception("The bands must be or all str or all int.") @@ -232,8 +235,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[str] | None = None, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, class_names: list[str] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, @@ -399,8 +402,8 @@ def __init__( ignore_split_file_extensions: bool = True, allow_substring_split_file: bool = True, rgb_indices: list[int] | None = None, - dataset_bands: list[HLSBands | int | list[int]] | None = None, - output_bands: list[HLSBands | int | list[int]] | None = None, + dataset_bands: list[HLSBands | int | list[int] | str ] | None = None, + output_bands: list[HLSBands | int | list[int] | str ] | None = None, constant_scale: float = 1, transform: A.Compose | None = None, no_data_replace: float | None = None,