From 1321f596d9a16fce0ba69e425e7f9470eb86e531 Mon Sep 17 00:00:00 2001
From: Pedro Henrique Conrado
Date: Mon, 11 Nov 2024 14:42:24 -0500
Subject: [PATCH 1/2] adds new datasets

---
 terratorch/datamodules/__init__.py            |  11 +
 terratorch/datamodules/biomassters.py         | 190 +++++++
 terratorch/datamodules/burn_intensity.py      |  94 ++++
 terratorch/datamodules/carbonflux.py          | 112 ++++
 terratorch/datamodules/forestnet.py           |  94 ++++
 .../generic_multimodal_data_module.py         |  62 +++
 terratorch/datamodules/landslide4sense.py     |  96 ++++
 .../multi_temporal_crop_classification.py     |   6 +-
 terratorch/datamodules/utils.py               |  26 +-
 terratorch/datasets/__init__.py               |  11 +
 terratorch/datasets/biomassters.py            | 507 ++++++++++++++++++
 terratorch/datasets/burn_intensity.py         | 250 +++++++++
 terratorch/datasets/carbonflux.py             | 247 +++++++++
 terratorch/datasets/forestnet.py              | 227 ++++++++
 .../datasets/generic_multimodal_dataset.py    |  23 +
 terratorch/datasets/landslide4sense.py        | 155 ++++++
 terratorch/datasets/m_forestnet.py            |   2 +-
 terratorch/datasets/transforms.py             |  40 +-
 terratorch/models/pixel_wise_model.py         |   6 +-
 terratorch/tasks/classification_tasks.py      |   8 +-
 terratorch/tasks/regression_tasks.py          |   5 +-
 terratorch/tasks/segmentation_tasks.py        |   5 +-
 22 files changed, 2166 insertions(+), 11 deletions(-)
 create mode 100644 terratorch/datamodules/biomassters.py
 create mode 100644 terratorch/datamodules/burn_intensity.py
 create mode 100644 terratorch/datamodules/carbonflux.py
 create mode 100644 terratorch/datamodules/forestnet.py
 create mode 100644 terratorch/datamodules/generic_multimodal_data_module.py
 create mode 100644 terratorch/datamodules/landslide4sense.py
 create mode 100644 terratorch/datasets/biomassters.py
 create mode 100644 terratorch/datasets/burn_intensity.py
 create mode 100644 terratorch/datasets/carbonflux.py
 create mode 100644 terratorch/datasets/forestnet.py
 create mode 100644 terratorch/datasets/generic_multimodal_dataset.py
 create mode 100644 terratorch/datasets/landslide4sense.py

diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py
index 02b46b14..1c2af9b0 100644
--- a/terratorch/datamodules/__init__.py
+++ b/terratorch/datamodules/__init__.py
@@ -33,6 +33,12 @@
 from terratorch.datamodules.sen4agrinet import Sen4AgriNetDataModule
 from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule
+from terratorch.datamodules.burn_intensity import BurnIntensityNonGeoDataModule
+from terratorch.datamodules.carbonflux import CarbonFluxNonGeoDataModule
+from terratorch.datamodules.landslide4sense import Landslide4SenseNonGeoDataModule
+from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule
+from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule
+
 # Generic classification datamodule
 from terratorch.datamodules.sen4map import Sen4MapLucasDataModule
@@ -42,6 +48,11 @@
     "GenericNonGeoSegmentationDataModule",
     "GenericNonGeoClassificationDataModule",
     # "GenericNonGeoRegressionDataModule",
+    "BurnIntensityNonGeoDataModule",
+    "CarbonFluxNonGeoDataModule",
+    "Landslide4SenseNonGeoDataModule",
+    "ForestNetNonGeoDataModule",
+    "BioMasstersNonGeoDataModule",
     "Sen1Floods11NonGeoDataModule",
     "Sen4MapLucasDataModule",
     "FireScarsNonGeoDataModule",
diff --git a/terratorch/datamodules/biomassters.py b/terratorch/datamodules/biomassters.py
new file mode 100644
index 00000000..eaa04471
--- /dev/null
+++ b/terratorch/datamodules/biomassters.py
@@ -0,0 +1,190 @@
+from collections.abc import Sequence
+from typing import Any
+
+import albumentations as A
+from torch.utils.data import DataLoader
+
+from terratorch.datamodules.generic_multimodal_data_module import MultimodalNormalize, wrap_in_compose_is_list
+from terratorch.datamodules.generic_pixel_wise_data_module import Normalize
+from terratorch.datasets import BioMasstersNonGeo
+from torchgeo.datamodules import NonGeoDataModule
+from torchgeo.transforms import AugmentationSequential
+
+MEANS = {
+    "AGBM": 63.4584,
+    "S1": {
+        "VV_Asc": 0.08871397,
+        "VH_Asc": 0.02172604,
+        "VV_Desc": 0.08556002,
+        "VH_Desc": 0.02795591,
+        "RVI_Asc": 0.75507677,
+        "RVI_Desc": 0.6600374
+    },
+    "S2": {
+        "BLUE": 1633.0802,
+        "GREEN": 1610.0035,
+        "RED": 1599.557,
+        "RED_EDGE_1": 1916.7083,
+        "RED_EDGE_2": 2478.8325,
+        "RED_EDGE_3": 2591.326,
+        "NIR_BROAD": 2738.5837,
+        "NIR_NARROW": 2685.8281,
+        "SWIR_1": 1023.90204,
+        "SWIR_2": 696.48755,
+        "CLOUD_PROBABILITY": 21.177078
+    }
+}
+
+STDS = {
+    "AGBM": 72.21242,
+    "S1": {
+        "VV_Asc": 0.16714208,
+        "VH_Asc": 0.04876742,
+        "VV_Desc": 0.19260046,
+        "VH_Desc": 0.10272296,
+        "RVI_Asc": 0.24945821,
+        "RVI_Desc": 0.3590119
+    },
+    "S2": {
+        "BLUE": 2499.7146,
+        "GREEN": 2308.5298,
+        "RED": 2388.2268,
+        "RED_EDGE_1": 2389.6375,
+        "RED_EDGE_2": 2209.6467,
+        "RED_EDGE_3": 2104.572,
+        "NIR_BROAD": 2194.209,
+        "NIR_NARROW": 2031.7762,
+        "SWIR_1": 934.0556,
+        "SWIR_2": 759.8444,
+        "CLOUD_PROBABILITY": 49.352486
+    }
+}
+
+class BioMasstersNonGeoDataModule(NonGeoDataModule):
+    """NonGeo datamodule implementation for BioMassters."""
+
+    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"
+
+    def __init__(
+        self,
+        data_root: str,
+        batch_size: int = 4,
+        num_workers: int = 0,
+        bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
+        train_transform: A.Compose | None | list[A.BasicTransform] = None,
+        val_transform: A.Compose | None | list[A.BasicTransform] = None,
+        test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        aug: AugmentationSequential = None,
+        drop_last: bool = True,
+        sensors: Sequence[str] = ["S1", "S2"],
+        as_time_series: bool = False,
+        metadata_filename: str = default_metadata_filename,
+        max_cloud_percentage: float | None = None,
+        max_red_mean: float | None = None,
+        include_corrupt: bool = True,
+        subset: float = 1,
+        seed: int = 42,
+        use_four_frames: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
+        self.data_root = data_root
+        self.sensors = sensors
+        if isinstance(bands, dict):
+            self.bands = bands
+        else:
+            sens = sensors[0]
+            self.bands = {sens: bands}
+
+        self.means = {}
+        self.stds = {}
+        for sensor in self.sensors:
+            self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
+            self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]
+
+        self.mask_mean = MEANS["AGBM"]
+        self.mask_std = STDS["AGBM"]
+        self.train_transform = wrap_in_compose_is_list(train_transform)
+        self.val_transform = wrap_in_compose_is_list(val_transform)
+        self.test_transform = wrap_in_compose_is_list(test_transform)
+        if len(sensors) == 1:
+            self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
+        else:
+            self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
+        self.drop_last = drop_last
+        self.as_time_series = as_time_series
+        self.metadata_filename = metadata_filename
+        self.max_cloud_percentage = max_cloud_percentage
+        self.max_red_mean = max_red_mean
+        self.include_corrupt = include_corrupt
+        self.subset = subset
+        self.seed = seed
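+        # When True, only four frames per chip are kept (the least cloudy if
+        # cloud metadata is available, otherwise the first four months); see
+        # BioMasstersNonGeo._select_4_frames in the dataset below.
+        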
self.use_four_frames = use_four_frames + + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = self.dataset_class( + split="train", + root=self.data_root, + transform=self.train_transform, + bands=self.bands, + mask_mean=self.mask_mean, + mask_std=self.mask_std, + sensors=self.sensors, + as_time_series=self.as_time_series, + metadata_filename=self.metadata_filename, + max_cloud_percentage=self.max_cloud_percentage, + max_red_mean=self.max_red_mean, + include_corrupt=self.include_corrupt, + subset=self.subset, + seed=self.seed, + use_four_frames=self.use_four_frames, + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( + split="test", + root=self.data_root, + transform=self.val_transform, + bands=self.bands, + mask_mean=self.mask_mean, + mask_std=self.mask_std, + sensors=self.sensors, + as_time_series=self.as_time_series, + metadata_filename=self.metadata_filename, + max_cloud_percentage=self.max_cloud_percentage, + max_red_mean=self.max_red_mean, + include_corrupt=self.include_corrupt, + subset=self.subset, + seed=self.seed, + use_four_frames=self.use_four_frames, + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( + split="test", + root=self.data_root, + transform=self.test_transform, + bands=self.bands, + mask_mean=self.mask_mean, + mask_std=self.mask_std, + sensors=self.sensors, + as_time_series=self.as_time_series, + metadata_filename=self.metadata_filename, + max_cloud_percentage=self.max_cloud_percentage, + max_red_mean=self.max_red_mean, + include_corrupt=self.include_corrupt, + subset=self.subset, + seed=self.seed, + use_four_frames=self.use_four_frames, + ) + + def _dataloader_factory(self, split: str): + dataset = self._valid_attribute(f"{split}_dataset", "dataset") + batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size") + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=split == "train", + num_workers=self.num_workers, + collate_fn=self.collate_fn, + drop_last=split =="train" and self.drop_last, + ) diff --git a/terratorch/datamodules/burn_intensity.py b/terratorch/datamodules/burn_intensity.py new file mode 100644 index 00000000..6c5b3343 --- /dev/null +++ b/terratorch/datamodules/burn_intensity.py @@ -0,0 +1,94 @@ +from collections.abc import Sequence +from typing import Any + +import albumentations as A + +from terratorch.datamodules.utils import NormalizeWithTimesteps, wrap_in_compose_is_list +from terratorch.datasets import BurnIntensityNonGeo +from torchgeo.datamodules import NonGeoDataModule + +MEANS = { + "BLUE": [331.6921, 896.8024, 348.8031], + "GREEN": [555.1077, 1093.9736, 500.2181], + "RED": [605.2513, 1142.7225, 597.9034], + "NIR": [1761.3884, 1890.2156, 1552.0403], + "SWIR_1": [1117.1825, 1408.0839, 1293.0919], + "SWIR_2": [2168.0090, 2270.9753, 1362.1312], +} + +STDS = { + "BLUE": [213.0656, 1620.4131, 314.7517], + "GREEN": [273.0910, 1628.4181, 365.6746], + "RED": [414.8322, 1600.7698, 424.8185], + "NIR": [818.7486, 1236.8453, 804.9058], + "SWIR_1": [677.2739, 1153.7432, 795.4156], + "SWIR_2": [612.9131, 1495.8365, 661.6196], +} + +class BurnIntensityNonGeoDataModule(NonGeoDataModule): + """NonGeo datamodule implementation for BurnIntensity.""" + + def __init__( + self, + data_root: str, + batch_size: int = 4, + num_workers: int = 0, + bands: Sequence[str] = BurnIntensityNonGeo.all_band_names, + train_transform: A.Compose | None | list[A.BasicTransform] = None, + val_transform: A.Compose | None | list[A.BasicTransform] = None, + test_transform: 
A.Compose | None | list[A.BasicTransform] = None,
+        use_full_data: bool = True,
+        no_data_replace: float | None = 0.0001,
+        no_label_replace: int | None = -1,
+        use_metadata: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
+        self.data_root = data_root
+
+        means = [MEANS[b] for b in bands]
+        stds = [STDS[b] for b in bands]
+        self.bands = bands
+        self.train_transform = wrap_in_compose_is_list(train_transform)
+        self.val_transform = wrap_in_compose_is_list(val_transform)
+        self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.aug = NormalizeWithTimesteps(means, stds)
+        self.use_full_data = use_full_data
+        self.no_data_replace = no_data_replace
+        self.no_label_replace = no_label_replace
+        self.use_metadata = use_metadata
+
+    def setup(self, stage: str) -> None:
+        if stage in ["fit"]:
+            self.train_dataset = self.dataset_class(
+                split="train",
+                data_root=self.data_root,
+                transform=self.train_transform,
+                bands=self.bands,
+                use_full_data=self.use_full_data,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
+        if stage in ["fit", "validate"]:
+            self.val_dataset = self.dataset_class(
+                split="val",
+                data_root=self.data_root,
+                transform=self.val_transform,
+                bands=self.bands,
+                use_full_data=self.use_full_data,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
+        if stage in ["test"]:
+            # BurnIntensity ships only train/val splits, so the val split is
+            # reused for testing.
+            self.test_dataset = self.dataset_class(
+                split="val",
+                data_root=self.data_root,
+                transform=self.test_transform,
+                bands=self.bands,
+                use_full_data=self.use_full_data,
+                no_data_replace=self.no_data_replace,
+                no_label_replace=self.no_label_replace,
+                use_metadata=self.use_metadata,
+            )
diff --git a/terratorch/datamodules/carbonflux.py b/terratorch/datamodules/carbonflux.py
new file mode 100644
index 00000000..fb2f145f
--- /dev/null
+++ b/terratorch/datamodules/carbonflux.py
@@ -0,0 +1,112 @@
+from collections.abc import Sequence
+from typing import Any
+
+import albumentations as A
+
+from terratorch.datamodules.generic_multimodal_data_module import MultimodalNormalize
+from terratorch.datamodules.generic_multimodal_data_module import wrap_in_compose_is_list
+from terratorch.datasets import CarbonFluxNonGeo
+from torchgeo.datamodules import NonGeoDataModule
+from torchgeo.transforms import AugmentationSequential
+
+MEANS = {
+    "image": {
+        "BLUE": 0.07372144372093026,
+        "GREEN": 0.10117611215116282,
+        "RED": 0.11269885680232558,
+        "NIR": 0.2775572554069766,
+        "SWIR_1": 0.21387001372093037,
+        "SWIR_2": 0.14144541145348838
+    },
+    "merra_vars": [282.373169, 296.706468, 288.852922, 278.612209, 0.540145,
+                   53.830276, 53.827718, 206.817980, 23.077581, 0.000003],
+    "mask": 3.668982
+}
+
+STDS = {
+    "image": {
+        "BLUE": 0.13324302628303733,
+        "GREEN": 0.13308921403475235,
+        "RED": 0.13829909331863693,
+        "NIR": 0.12039809083338567,
+        "SWIR_1": 0.1088096350639653,
+        "SWIR_2": 0.09366368859284444
+    },
+    "merra_vars": [9.296960, 11.402008, 10.311107, 8.064209, 0.171909,
+                   49.945953, 48.907351, 74.591578, 8.746668, 0.000014],
+    "mask": 3.804261
+}
+
+
+class CarbonFluxNonGeoDataModule(NonGeoDataModule):
+    """NonGeo datamodule implementation for CarbonFlux."""
+
+    def __init__(
+        self,
+        data_root: str,
+        batch_size: int = 4,
+        num_workers: int = 0,
+        bands: Sequence[str] = CarbonFluxNonGeo.all_band_names,
+        train_transform: A.Compose | None | list[A.BasicTransform] = None,
+        val_transform: A.Compose | None | list[A.BasicTransform] = None,
+        test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        aug: AugmentationSequential = None,
+        no_data_replace: float | None = 0.0001,
+        use_metadata: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(CarbonFluxNonGeo, batch_size, num_workers, **kwargs)
+        self.data_root = data_root
+
+        means = {
+            m: ([MEANS[m][band] for band in bands] if m == "image" else MEANS[m])
+            for m in MEANS.keys()
+        }
+        stds = {
+            m: ([STDS[m][band] for band in bands] if m == "image" else STDS[m])
+            for m in STDS.keys()
+        }
+        self.mask_means = MEANS["mask"]
+        self.mask_std = STDS["mask"]
+        self.bands = bands
+        self.train_transform = wrap_in_compose_is_list(train_transform)
+        self.val_transform = wrap_in_compose_is_list(val_transform)
+        self.test_transform = wrap_in_compose_is_list(test_transform)
+        self.aug = MultimodalNormalize(means, stds) if aug is None else aug
+        self.no_data_replace = no_data_replace
+        self.use_metadata = use_metadata
+
+    def setup(self, stage: str) -> None:
+        if stage in ["fit"]:
+            self.train_dataset = self.dataset_class(
+                split="train",
+                data_root=self.data_root,
+                transform=self.train_transform,
+                bands=self.bands,
+                gpp_mean=self.mask_means,
+                gpp_std=self.mask_std,
+                no_data_replace=self.no_data_replace,
+                use_metadata=self.use_metadata,
+            )
+        if stage in ["fit", "validate"]:
+            # CarbonFlux has no dedicated validation split; the test split is
+            # used for validation.
+            self.val_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.val_transform,
+                bands=self.bands,
+                gpp_mean=self.mask_means,
+                gpp_std=self.mask_std,
+                no_data_replace=self.no_data_replace,
+                use_metadata=self.use_metadata,
+            )
+        if stage in ["test"]:
+            self.test_dataset = self.dataset_class(
+                split="test",
+                data_root=self.data_root,
+                transform=self.test_transform,
+                bands=self.bands,
+                gpp_mean=self.mask_means,
+                gpp_std=self.mask_std,
+                no_data_replace=self.no_data_replace,
+                use_metadata=self.use_metadata,
+            )
diff --git a/terratorch/datamodules/forestnet.py b/terratorch/datamodules/forestnet.py
new file mode 100644
index 00000000..c78108d5
--- /dev/null
+++ b/terratorch/datamodules/forestnet.py
@@ -0,0 +1,94 @@
+from collections.abc import Sequence
+from typing import Any
+
+import albumentations as A
+import kornia.augmentation as K  # noqa: N812
+
+from terratorch.datamodules.generic_pixel_wise_data_module import Normalize
+from terratorch.datamodules.generic_multimodal_data_module import wrap_in_compose_is_list
+from terratorch.datasets import ForestNetNonGeo
+from torchgeo.datamodules import NonGeoDataModule
+from torchgeo.transforms import AugmentationSequential
+
+MEANS = {
+    "BLUE": 19.8680,
+    "GREEN": 28.1656,
+    "RED": 14.9309,
+    "NIR": 82.1076,
+    "SWIR_1": 39.4819,
+    "SWIR_2": 17.7241
+}
+
+STDS = {
+    "BLUE": 17.4523,
+    "GREEN": 15.8399,
+    "RED": 17.9444,
+    "NIR": 21.4439,
+    "SWIR_1": 14.4642,
+    "SWIR_2": 9.9120
+}
+
+
+class ForestNetNonGeoDataModule(NonGeoDataModule):
+    """NonGeo datamodule implementation for ForestNet."""
+
+    def __init__(
+        self,
+        data_root: str,
+        batch_size: int = 4,
+        num_workers: int = 0,
+        label_map: dict[str, int] = ForestNetNonGeo.default_label_map,
+        bands: Sequence[str] = ForestNetNonGeo.all_band_names,
+        train_transform: A.Compose | None | list[A.BasicTransform] = None,
+        val_transform: A.Compose | None | list[A.BasicTransform] = None,
+        test_transform: A.Compose | None | list[A.BasicTransform] = None,
+        fraction: float = 1.0,
+        aug: AugmentationSequential = None,
+        use_metadata: bool = False,
+        **kwargs: Any,
+    ) -> None:
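+        # A minimal usage sketch (hypothetical data_root; illustrative only):
+        #   dm = ForestNetNonGeoDataModule(data_root="data/forestnet", bands=["RED", "GREEN", "BLUE"])
+        #   dm.setup("fit")
+        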
super().__init__(ForestNetNonGeo, batch_size, num_workers, **kwargs) + self.data_root = data_root + + self.means = [MEANS[b] for b in bands] + self.stds = [STDS[b] for b in bands] + self.label_map = label_map + self.bands = bands + self.train_transform = wrap_in_compose_is_list(train_transform) + self.val_transform = wrap_in_compose_is_list(val_transform) + self.test_transform = wrap_in_compose_is_list(test_transform) + self.aug = Normalize(self.means, self.stds) if aug is None else aug + self.fraction = fraction + self.use_metadata = use_metadata + + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = self.dataset_class( + split="train", + data_root=self.data_root, + label_map=self.label_map, + transform=self.train_transform, + bands=self.bands, + fraction=self.fraction, + use_metadata=self.use_metadata, + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( + split="val", + data_root=self.data_root, + label_map=self.label_map, + transform=self.val_transform, + bands=self.bands, + fraction=self.fraction, + use_metadata=self.use_metadata, + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( + split="test", + data_root=self.data_root, + label_map=self.label_map, + transform=self.test_transform, + bands=self.bands, + fraction=self.fraction, + use_metadata=self.use_metadata, + ) diff --git a/terratorch/datamodules/generic_multimodal_data_module.py b/terratorch/datamodules/generic_multimodal_data_module.py new file mode 100644 index 00000000..97a13439 --- /dev/null +++ b/terratorch/datamodules/generic_multimodal_data_module.py @@ -0,0 +1,62 @@ +from collections.abc import Callable, Iterable + +import albumentations as A +import torch + + +def wrap_in_compose_is_list(transform_list, image_modalities=None, non_image_modalities=None): + additional_targets = {} + if image_modalities: + for modality in image_modalities: + additional_targets[modality] = "image" + if non_image_modalities: + # Global label values are ignored and need to be processed separately + for modality in non_image_modalities: + additional_targets[modality] = "global_label" + # set check shapes to false because of the multitemporal case + return A.Compose(transform_list, is_check_shapes=False, additional_targets=additional_targets) \ + if isinstance(transform_list, Iterable) else transform_list + +class MultimodalNormalize(Callable): + def __init__(self, means, stds): + super().__init__() + self.means = means + self.stds = stds + + def __call__(self, batch): + for m in self.means.keys(): + if m not in batch["image"]: + continue + image = batch["image"][m] + if len(image.shape) == 5: + # B, C, T, H, W + means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1, 1) + elif len(image.shape) == 4: + # B, C, H, W + means = torch.tensor(self.means[m], device=image.device).view(1, -1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(1, -1, 1, 1) + elif len(self.means[m]) == 1: + # B, (T,) H, W + means = torch.tensor(self.means[m], device=image.device) + stds = torch.tensor(self.stds[m], device=image.device) + elif len(image.shape) == 3: # No batch dim + # C, H, W + means = torch.tensor(self.means[m], device=image.device).view(-1, 1, 1) + stds = torch.tensor(self.stds[m], device=image.device).view(-1, 1, 1) + + elif len(image.shape) == 2: + means = torch.tensor(self.means[m], device=image.device) + stds = torch.tensor(self.stds[m], device=image.device) + 
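+            # Unbatched 1-D vectors fall through to element-wise normalization
+            # against mean/std vectors of the same length.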
+ elif len(image.shape) == 1: + means = torch.tensor(self.means[m], device=image.device) + stds = torch.tensor(self.stds[m], device=image.device) + + else: + msg = (f"Expected batch with 5 or 4 dimensions (B, C, (T,) H, W), sample with 3 dimensions (C, H, W) " + f"or a single channel, but got {len(image.shape)}") + raise Exception(msg) + batch["image"][m] = (image - means) / stds + return batch + diff --git a/terratorch/datamodules/landslide4sense.py b/terratorch/datamodules/landslide4sense.py new file mode 100644 index 00000000..0e843907 --- /dev/null +++ b/terratorch/datamodules/landslide4sense.py @@ -0,0 +1,96 @@ +from collections.abc import Sequence +from typing import Any + +import albumentations as A +import kornia.augmentation as K # noqa: N812 +from torchgeo.datamodules import NonGeoDataModule +from torchgeo.transforms import AugmentationSequential + +from terratorch.datamodules.generic_multimodal_data_module import wrap_in_compose_is_list +from terratorch.datasets import Landslide4SenseNonGeo + +MEANS = { + "COASTAL AEROSOL": -0.4914, + "BLUE": -0.3074, + "GREEN": -0.1277, + "RED": -0.0625, + "RED_EDGE_1": 0.0439, + "RED_EDGE_2": 0.0803, + "RED_EDGE_3": 0.0644, + "NIR_BROAD": 0.0802, + "WATER_VAPOR": 0.3000, + "CIRRUS": 0.4082, + "SWIR_1": 0.0823, + "SWIR_2": 0.0516, + "SLOPE": 0.3338, + "DEM": 0.7819, +} + +STDS = { + "COASTAL AEROSOL": 0.9325, + "BLUE": 0.8775, + "GREEN": 0.8860, + "RED": 0.8869, + "RED_EDGE_1": 0.8857, + "RED_EDGE_2": 0.8418, + "RED_EDGE_3": 0.8354, + "NIR_BROAD": 0.8491, + "WATER_VAPOR": 0.9061, + "CIRRUS": 1.6072, + "SWIR_1": 0.8848, + "SWIR_2": 0.9232, + "SLOPE": 0.9018, + "DEM": 1.2913, +} + + +class Landslide4SenseNonGeoDataModule(NonGeoDataModule): + """NonGeo datamodule implementation for Landslide4Sense.""" + + def __init__( + self, + data_root: str, + batch_size: int = 4, + num_workers: int = 0, + bands: Sequence[str] = Landslide4SenseNonGeo.all_band_names, + train_transform: A.Compose | None | list[A.BasicTransform] = None, + val_transform: A.Compose | None | list[A.BasicTransform] = None, + test_transform: A.Compose | None | list[A.BasicTransform] = None, + aug: AugmentationSequential = None, + **kwargs: Any, + ) -> None: + super().__init__(Landslide4SenseNonGeo, batch_size, num_workers, **kwargs) + self.data_root = data_root + + self.means = [MEANS[b] for b in bands] + self.stds = [STDS[b] for b in bands] + self.bands = bands + self.train_transform = wrap_in_compose_is_list(train_transform) + self.val_transform = wrap_in_compose_is_list(val_transform) + self.test_transform = wrap_in_compose_is_list(test_transform) + self.aug = ( + AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug + ) + + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = self.dataset_class( + split="train", + data_root=self.data_root, + transform=self.train_transform, + bands=self.bands + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( + split="val", + data_root=self.data_root, + transform=self.val_transform, + bands=self.bands + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( + split="test", + data_root=self.data_root, + transform=self.test_transform, + bands=self.bands + ) diff --git a/terratorch/datamodules/multi_temporal_crop_classification.py b/terratorch/datamodules/multi_temporal_crop_classification.py index 103900be..4957e088 100644 --- a/terratorch/datamodules/multi_temporal_crop_classification.py +++ 
b/terratorch/datamodules/multi_temporal_crop_classification.py @@ -52,13 +52,13 @@ def __init__( super().__init__(MultiTemporalCropClassification, batch_size, num_workers, **kwargs) self.data_root = data_root - means = [MEANS[b] for b in bands] - stds = [STDS[b] for b in bands] + self.means = [MEANS[b] for b in bands] + self.stds = [STDS[b] for b in bands] self.bands = bands self.train_transform = wrap_in_compose_is_list(train_transform) self.val_transform = wrap_in_compose_is_list(val_transform) self.test_transform = wrap_in_compose_is_list(test_transform) - self.aug = Normalize(means, stds) + self.aug = Normalize(self.means, self.stds) self.drop_last = drop_last self.no_data_replace = no_data_replace self.no_label_replace = no_label_replace diff --git a/terratorch/datamodules/utils.py b/terratorch/datamodules/utils.py index c8bd4e96..fdcbd019 100644 --- a/terratorch/datamodules/utils.py +++ b/terratorch/datamodules/utils.py @@ -1,12 +1,12 @@ # Copyright contributors to the Terratorch project import re -from collections.abc import Iterable +from collections.abc import Callable, Iterable import albumentations as A import numpy as np +import torch -np_str_obj_array_pattern = re.compile(r"[SaUO]") def wrap_in_compose_is_list(transform_list): # set check shapes to false because of the multitemporal case @@ -22,4 +22,26 @@ def check_dataset_stackability(dataset, batch_size) -> bool: print("The batch samples can't be stacked, since they don't have the same dimensions. Setting batch_size=1.") return 1 +class NormalizeWithTimesteps(Callable): + def __init__(self, means, stds): + super().__init__() + self.means = means # (C, T) + self.stds = stds # (C, T) + def __call__(self, batch): + image = batch["image"] + + if len(image.shape) == 5: # (B, T, C, H, W) + means = torch.tensor(self.means, device=image.device).transpose(0, 1).reshape(1, image.shape[1], image.shape[2], 1, 1) + stds = torch.tensor(self.stds, device=image.device).transpose(0, 1).reshape(1, image.shape[1], image.shape[2], 1, 1) + + elif len(image.shape) == 4: # (B, C, H, W) + means = torch.tensor(self.means, device=image.device).mean(dim=1).view(1, image.shape[1], 1, 1) + stds = torch.tensor(self.stds, device=image.device).mean(dim=1).view(1, image.shape[1], 1, 1) + + else: + msg = f"Expected batch to have 5 or 4 dimensions, but got {len(image.shape)}" + raise Exception(msg) + + batch["image"] = (image - means) / stds + return batch diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py index 41694be3..9f7f3bb1 100644 --- a/terratorch/datasets/__init__.py +++ b/terratorch/datasets/__init__.py @@ -34,6 +34,12 @@ from terratorch.datasets.sen4agrinet import Sen4AgriNet from terratorch.datasets.utils import HLSBands +from terratorch.datasets.burn_intensity import BurnIntensityNonGeo +from terratorch.datasets.carbonflux import CarbonFluxNonGeo +from terratorch.datasets.landslide4sense import Landslide4SenseNonGeo +from terratorch.datasets.forestnet import ForestNetNonGeo +from terratorch.datasets.biomassters import BioMasstersNonGeo + # TorchGeo RasterDatasets from terratorch.datasets.wsf import WSF2019, WSFEvolution @@ -45,6 +51,11 @@ "GenericNonGeoPixelwiseRegressionDataset", "GenericNonGeoClassificationDataset", "GenericNonGeoRegressionDataset", + "BurnIntensityNonGeo", + "CarbonFluxNonGeo", + "Landslide4SenseNonGeo", + "BioMasstersNonGeo", + "ForestNetNonGeo", "FireScarsNonGeo", "FireScarsHLS", "FireScarsSegmentationMask", diff --git a/terratorch/datasets/biomassters.py b/terratorch/datasets/biomassters.py new 
file mode 100644 index 00000000..323fcc6c --- /dev/null +++ b/terratorch/datasets/biomassters.py @@ -0,0 +1,507 @@ +import os +import random +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Union + +import albumentations as A +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from terratorch.datasets.generic_multimodal_dataset import MultimodalToTensor +from terratorch.datasets.transforms import MultimodalTransforms +from terratorch.datasets.utils import default_transform +from torchgeo.datasets import BioMassters +from torchgeo.datasets.utils import percentile_normalization + + +class BioMasstersNonGeo(BioMassters): + """BioMassters Dataset for Aboveground Biomass prediction. + + Dataset intended for Aboveground Biomass (AGB) prediction + over Finnish forests based on Sentinel 1 and 2 data with + corresponding target AGB mask values generated by Light Detection + and Ranging (LiDAR). + + Dataset Format: + + * .tif files for Sentinel 1 and 2 data + * .tif file for pixel wise AGB target mask + * .csv files for metadata regarding features and targets + + Dataset Features: + + * 13,000 target AGB masks of size (256x256px) + * 12 months of data per target mask + * Sentinel 1 and Sentinel 2 data for each location + * Sentinel 1 available for every month + * Sentinel 2 available for almost every month + (not available for every month due to ESA acquisition halt over the region + during particular periods) + + If you use this dataset in your research, please cite the following paper: + + * https://nascetti-a.github.io/BioMasster/ + + .. versionadded:: 0.5 + """ + + S1_BAND_NAMES = ["VV_Asc", "VH_Asc", "VV_Desc", "VH_Desc", "RVI_Asc", "RVI_Desc"] + S2_BAND_NAMES = [ + "BLUE", + "GREEN", + "RED", + "RED_EDGE_1", + "RED_EDGE_2", + "RED_EDGE_3", + "NIR_BROAD", + "NIR_NARROW", + "SWIR_1", + "SWIR_2", + "CLOUD_PROBABILITY", + ] + + all_band_names = { + "S1": S1_BAND_NAMES, + "S2": S2_BAND_NAMES, + } + + rgb_bands = { + "S1": [], + "S2": ["RED", "GREEN", "BLUE"], + } + + valid_splits = ("train", "test") + valid_sensors = ("S1", "S2") + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv" + + def __init__( + self, + root = "data", + split: str = "train", + bands: dict[str, Sequence[str]] | Sequence[str] = BAND_SETS["all"], + transform: A.Compose | None = None, + mask_mean: float | None = 63.4584, + mask_std: float | None = 72.21242, + sensors: Sequence[str] = ["S1", "S2"], + as_time_series: bool = False, + metadata_filename: str = default_metadata_filename, + max_cloud_percentage: float | None = None, + max_red_mean: float | None = None, + include_corrupt: bool = True, + subset: float = 1, + seed: int = 42, + use_four_frames: bool = False + ) -> None: + """Initialize a new instance of BioMassters dataset. + + If ``as_time_series=False`` (the default), each time step becomes its own + sample with the target being shared across multiple samples. 
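+
+        A minimal usage sketch (hypothetical root path and band subset):
+
+        .. code-block:: python
+
+            ds = BioMasstersNonGeo(
+                root="data/biomassters",
+                split="train",
+                sensors=["S2"],
+                bands={"S2": ["RED", "GREEN", "BLUE"]},
+                as_time_series=True,
+            )
+            sample = ds[0]  # dict with "image" and "mask" entries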
+ + Args: + root: root directory where dataset can be found + split: train or test split + sensors: which sensors to consider for the sample, Sentinel 1 and/or + Sentinel 2 ('S1', 'S2') + as_time_series: whether or not to return all available + time-steps or just a single one for a given target location + metadata_filename: metadata file to be used + max_cloud_percentage: maximum allowed cloud percentage for images + max_red_mean: maximum allowed red_mean value for images + include_corrupt: whether to include images marked as corrupted + + Raises: + AssertionError: if ``split`` or ``sensors`` is invalid + DatasetNotFoundError: If dataset is not found. + """ + self.root = root + self.sensors = sensors + self.bands = bands + assert ( + split in self.valid_splits + ), f"Please choose one of the valid splits: {self.valid_splits}." + self.split = split + + assert set(sensors).issubset( + set(self.valid_sensors) + ), f"Please choose a subset of valid sensors: {self.valid_sensors}." + + if len(self.sensors) == 1: + sens = self.sensors[0] + self.band_indices = [ + self.all_band_names[sens].index(band) for band in self.bands[sens] + ] + else: + self.band_indices = { + sens: [self.all_band_names[sens].index(band) for band in self.bands[sens]] + for sens in self.sensors + } + + self.mask_mean = mask_mean + self.mask_std = mask_std + self.as_time_series = as_time_series + self.metadata_filename = metadata_filename + self.max_cloud_percentage = max_cloud_percentage + self.max_red_mean = max_red_mean + self.include_corrupt = include_corrupt + self.subset = subset + self.seed = seed + self.use_four_frames = use_four_frames + + self._verify() + + # open metadata csv files + self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename)) + + # Filter sensors + self.df = self.df[self.df["satellite"].isin(self.sensors)] + + # Filter split + self.df = self.df[self.df["split"] == self.split] + + # Optional filtering + self._filter_and_select_data() + + # Optional subsampling + self._random_subsample() + + # generate numerical month from filename since first month is September + # and has numerical index of 0 + self.df["num_month"] = ( + self.df["filename"] + .str.split("_", expand=True)[2] + .str.split(".", expand=True)[0] + .astype(int) + ) + + # Set dataframe index depending on the task for easier indexing + if self.as_time_series: + self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup() + else: + filter_df = ( + self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index() + ) + filter_df = filter_df[ + filter_df["satellite"] == len(self.sensors) + ].drop("satellite", axis=1) + # Guarantee that each sample has corresponding number of images available + self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner") + + self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup() + + # Adjust transforms based on the number of sensors + if len(self.sensors) == 1: + self.transform = transform if transform else default_transform + elif transform is None: + self.transform = MultimodalToTensor(self.sensors) + else: + transform = { + s: transform[s] if s in transform else default_transform + for s in self.sensors + } + self.transform = MultimodalTransforms(transform, shared=False) + + if self.use_four_frames: + self._select_4_frames() + + def __len__(self) -> int: + return len(self.df["num_index"].unique()) + + def _load_input(self, filenames: list[Path]) -> Tensor: + """Load the input imagery at the index. 
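+
+        Monthly rasters are stacked along a new leading time axis when
+        ``as_time_series=True``; otherwise they are concatenated along the
+        channel axis.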
+ + Args: + filenames: list of filenames corresponding to input + + Returns: + input image + """ + filepaths = [ + os.path.join(self.root, f"{self.split}_features", f) for f in filenames + ] + arr_list = [rasterio.open(fp).read() for fp in filepaths] + + if self.as_time_series: + arr = np.stack(arr_list, axis=0) # (T, C, H, W) + else: + arr = np.concatenate(arr_list, axis=0) + return arr.astype(np.int32) + + def _load_target(self, filename: Path) -> Tensor: + """Load the target mask at the index. + + Args: + filename: filename of target to index + + Returns: + target mask + """ + with rasterio.open(os.path.join(self.root, f"{self.split}_agbm", filename), "r") as src: + arr: np.typing.NDArray[np.float64] = src.read() + + return arr + + def _compute_rvi(self, img: np.ndarray, linear: np.ndarray, sens: str) -> np.ndarray: + """Compute the RVI indices for S1 data.""" + rvi_channels = [] + if self.as_time_series: + if "RVI_Asc" in self.bands[sens]: + try: + vv_asc_index = self.all_band_names["S1"].index("VV_Asc") + vh_asc_index = self.all_band_names["S1"].index("VH_Asc") + except ValueError as e: + msg = f"RVI_Asc needs band: {e}" + raise ValueError(msg) from e + + VV = linear[:, vv_asc_index, :, :] + VH = linear[:, vh_asc_index, :, :] + rvi_asc = 4 * VH / (VV + VH + 1e-6) + rvi_asc = np.expand_dims(rvi_asc, axis=1) + rvi_channels.append(rvi_asc) + if "RVI_Desc" in self.bands[sens]: + try: + vv_desc_index = self.all_band_names["S1"].index("VV_Desc") + vh_desc_index = self.all_band_names["S1"].index("VH_Desc") + except ValueError as e: + msg = f"RVI_Desc needs band: {e}" + raise ValueError(msg) from e + + VV_desc = linear[:, vv_desc_index, :, :] + VH_desc = linear[:, vh_desc_index, :, :] + rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6) + rvi_desc = np.expand_dims(rvi_desc, axis=1) + rvi_channels.append(rvi_desc) + if rvi_channels: + rvi_concat = np.concatenate(rvi_channels, axis=1) + img = np.concatenate([img, rvi_concat], axis=1) + else: + if "RVI_Asc" in self.bands[sens]: + if linear.shape[0] < 2: + msg = f"Not enough bands to calculate RVI_Asc. Available bands: {linear.shape[0]}" + raise ValueError(msg) + VV = linear[0] + VH = linear[1] + rvi_asc = 4 * VH / (VV + VH + 1e-6) + rvi_asc = np.expand_dims(rvi_asc, axis=0) + rvi_channels.append(rvi_asc) + if "RVI_Desc" in self.bands[sens]: + if linear.shape[0] < 4: + msg = f"Not enough bands to calculate RVI_Desc. 
Available bands: {linear.shape[0]}"
+                    raise ValueError(msg)
+                VV_desc = linear[2]
+                VH_desc = linear[3]
+                rvi_desc = 4 * VH_desc / (VV_desc + VH_desc + 1e-6)
+                rvi_desc = np.expand_dims(rvi_desc, axis=0)
+                rvi_channels.append(rvi_desc)
+            if rvi_channels:
+                rvi_concat = np.concatenate(rvi_channels, axis=0)
+                img = np.concatenate([linear, rvi_concat], axis=0)
+        return img
+
+    def _select_4_frames(self):
+        """Filter the dataset to select only 4 frames per sample."""
+
+        if "cloud_percentage" in self.df.columns:
+            self.df = self.df.sort_values(by=["chip_id", "cloud_percentage"])
+        else:
+            self.df = self.df.sort_values(by=["chip_id", "num_month"])
+
+        self.df = (
+            self.df.groupby("chip_id")
+            .head(4)  # Select the first 4 frames per chip
+            .reset_index(drop=True)
+        )
+
+    def _process_sensor_images(self, sens: str, sens_filepaths: list[str]) -> np.ndarray:
+        """Process images for a given sensor."""
+        img = self._load_input(sens_filepaths)
+        if sens == "S1":
+            img = img.astype(np.float32)
+            linear = 10 ** (img / 10)
+            img = self._compute_rvi(img, linear, sens)
+        if self.as_time_series:
+            img = img.transpose(0, 2, 3, 1)  # (T, H, W, C)
+        else:
+            img = img.transpose(1, 2, 0)  # (H, W, C)
+        if len(self.sensors) == 1:
+            img = img[..., self.band_indices]
+        else:
+            img = img[..., self.band_indices[sens]]
+        return img
+
+    def __getitem__(self, index: int) -> dict:
+        sample_df = self.df[self.df["num_index"] == index].copy()
+        # Sort by satellite and month
+        sample_df.sort_values(
+            by=["satellite", "num_month"], inplace=True, ascending=True
+        )
+
+        filepaths = sample_df["filename"].tolist()
+        output = {}
+
+        if len(self.sensors) == 1:
+            sens = self.sensors[0]
+            sens_filepaths = [fp for fp in filepaths if sens in fp]
+            img = self._process_sensor_images(sens, sens_filepaths)
+            output["image"] = img.astype(np.float32)
+        else:
+            for sens in self.sensors:
+                sens_filepaths = [fp for fp in filepaths if sens in fp]
+                img = self._process_sensor_images(sens, sens_filepaths)
+                output[sens] = img.astype(np.float32)
+
+        # Load target
+        target_filename = sample_df["corresponding_agbm"].unique()[0]
+        target = np.array(self._load_target(Path(target_filename)))
+        target = target.transpose(1, 2, 0)
+        output["mask"] = target
+        if self.transform:
+            if len(self.sensors) == 1:
+                output = self.transform(**output)
+            else:
+                output = self.transform(output)
+        output["mask"] = output["mask"].squeeze().float()
+        return output
+
+    def _filter_and_select_data(self):
+        if (
+            self.max_cloud_percentage is not None
+            and "cloud_percentage" in self.df.columns
+        ):
+            self.df = self.df[self.df["cloud_percentage"] <= self.max_cloud_percentage]
+
+        if self.max_red_mean is not None and "red_mean" in self.df.columns:
+            self.df = self.df[self.df["red_mean"] <= self.max_red_mean]
+
+        if not self.include_corrupt and "corrupt_values" in self.df.columns:
+            # Elementwise comparison; `is False` would compare the Series
+            # object itself and never filter any rows.
+            self.df = self.df[self.df["corrupt_values"] == False]  # noqa: E712
+
+    def _random_subsample(self):
+        if self.split == "train" and self.subset < 1.0:
+            num_samples = int(len(self.df["num_index"].unique()) * self.subset)
+            if self.seed is not None:
+                random.seed(self.seed)
+            selected_indices = random.sample(
+                list(self.df["num_index"].unique()), num_samples
+            )
+            self.df = self.df[self.df["num_index"].isin(selected_indices)]
+            self.df.reset_index(drop=True, inplace=True)
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: str | None = None,
+    ) -> Figure:
+        """Plot a sample from the dataset.
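+
+        S2 imagery is rendered as an RGB composite; S1 imagery as a
+        false-colour composite of co-polarization, cross-polarization and
+        their ratio.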
+ + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + # Determine if the sample contains multiple sensors or a single sensor + if isinstance(sample["image"], dict): + ncols = len(self.sensors) + 1 + else: + ncols = 2 # One for the image and one for the mask + + showing_predictions = "prediction" in sample + if showing_predictions: + ncols += 1 + + fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10)) + + if isinstance(sample["image"], dict): + # Multiple sensors case + for idx, sens in enumerate(self.sensors): + img = sample["image"][sens].numpy() + if self.as_time_series: + # Plot last time step + img = img[:, -1, ...] + if sens == "S2": + img = img[[2, 1, 0], ...].transpose(1, 2, 0) + img = percentile_normalization(img) + else: + co_polarization = img[0] # transmit == receive + cross_polarization = img[1] # transmit != receive + ratio = co_polarization / (cross_polarization + 1e-6) + + co_polarization = np.clip(co_polarization / 0.3, 0, 1) + cross_polarization = np.clip(cross_polarization / 0.05, 0, 1) + ratio = np.clip(ratio / 25, 0, 1) + + img = np.stack( + (co_polarization, cross_polarization, ratio), axis=0 + ) + img = img.transpose(1, 2, 0) # Convert to (H, W, 3) + + axs[idx].imshow(img) + axs[idx].axis("off") + if show_titles: + axs[idx].set_title(sens) + mask_idx = len(self.sensors) + else: + # Single sensor case + sens = self.sensors[0] + img = sample["image"].numpy() + if self.as_time_series: + # Plot last time step + img = img[:, -1, ...] + if sens == "S2": + img = img[[2, 1, 0], ...].transpose(1, 2, 0) + img = percentile_normalization(img) + else: + co_polarization = img[0] # transmit == receive + cross_polarization = img[1] # transmit != receive + ratio = co_polarization / (cross_polarization + 1e-6) + + co_polarization = np.clip(co_polarization / 0.3, 0, 1) + cross_polarization = np.clip(cross_polarization / 0.05, 0, 1) + ratio = np.clip(ratio / 25, 0, 1) + + img = np.stack( + (co_polarization, cross_polarization, ratio), axis=0 + ) + img = img.transpose(1, 2, 0) # Convert to (H, W, 3) + + axs[0].imshow(img) + axs[0].axis("off") + if show_titles: + axs[0].set_title(sens) + mask_idx = 1 + + # Plot target mask + if "mask" in sample: + target = sample["mask"].squeeze() + target_im = axs[mask_idx].imshow(target, cmap="YlGn") + plt.colorbar(target_im, ax=axs[mask_idx], fraction=0.046, pad=0.04) + axs[mask_idx].axis("off") + if show_titles: + axs[mask_idx].set_title("Target") + + # Plot prediction if available + if showing_predictions: + pred_idx = mask_idx + 1 + prediction = sample["prediction"].squeeze() + pred_im = axs[pred_idx].imshow(prediction, cmap="YlGn") + plt.colorbar(pred_im, ax=axs[pred_idx], fraction=0.046, pad=0.04) + axs[pred_idx].axis("off") + if show_titles: + axs[pred_idx].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/terratorch/datasets/burn_intensity.py b/terratorch/datasets/burn_intensity.py new file mode 100644 index 00000000..999cc663 --- /dev/null +++ b/terratorch/datasets/burn_intensity.py @@ -0,0 +1,250 @@ +import os +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import albumentations as A +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rioxarray +import torch +from 
matplotlib.colors import Normalize
+from torch import Tensor
+from xarray import DataArray
+
+from terratorch.datasets.utils import default_transform, validate_bands
+from torchgeo.datasets import NonGeoDataset
+
+
+class BurnIntensityNonGeo(NonGeoDataset):
+    """Dataset implementation for Burn Intensity classification."""
+
+    all_band_names = (
+        "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2",
+    )
+
+    rgb_bands = ("RED", "GREEN", "BLUE")
+
+    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
+
+    class_names = (
+        "No burn",
+        "Unburned to Very Low",
+        "Low Severity",
+        "Moderate Severity",
+        "High Severity"
+    )
+
+    CSV_FILES = {
+        "limited": "BS_files_with_less_than_25_percent_zeros.csv",
+        "full": "BS_files_raw.csv",
+    }
+
+    num_classes = 5
+    splits = {"train": "train", "val": "val"}
+    time_steps = ["pre", "during", "post"]
+
+    def __init__(
+        self,
+        data_root: str,
+        split: str = "train",
+        bands: Sequence[str] = BAND_SETS["all"],
+        transform: A.Compose | None = None,
+        use_full_data: bool = True,
+        no_data_replace: float | None = 0.0001,
+        no_label_replace: int | None = -1,
+        use_metadata: bool = False,
+    ) -> None:
+        """Initialize the BurnIntensity dataset.
+
+        Args:
+            data_root (str): Path to the data root directory.
+            split (str): One of 'train' or 'val'.
+            bands (Sequence[str]): Bands to output. Defaults to all bands.
+            transform (Optional[A.Compose]): Albumentations transform to be applied.
+            use_full_data (bool): Whether to use the full data or only the files with less than 25 percent zeros.
+            no_data_replace (Optional[float]): Value to replace NaNs in images.
+            no_label_replace (Optional[int]): Value to replace NaNs in labels.
+            use_metadata (bool): Whether to return metadata info (location).
+        """
+        super().__init__()
+        if split not in self.splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
+ raise ValueError(msg) + self.split = split + + validate_bands(bands, self.all_band_names) + self.bands = bands + self.band_indices = np.asarray([self.all_band_names.index(b) for b in bands]) + + self.data_root = Path(data_root) + + # Read the CSV file to get the list of cases to include + csv_file_key = "full" if use_full_data else "limited" + csv_path = self.data_root / self.CSV_FILES[csv_file_key] + df = pd.read_csv(csv_path) + casenames = df["Case_Name"].tolist() + + split_file = self.data_root / f"{split}.txt" + with open(split_file) as f: + split_images = [line.strip() for line in f.readlines()] + + split_images = [img for img in split_images if self._extract_casename(img) in casenames] + + # Build the samples list + self.samples = [] + for image_filename in split_images: + image_files = [] + for time_step in self.time_steps: + image_file = self.data_root / time_step / image_filename + image_files.append(str(image_file)) + mask_filename = image_filename.replace("HLS_", "BS_") + mask_file = self.data_root / "pre" / mask_filename + self.samples.append({ + "image_files": image_files, + "mask_file": str(mask_file), + "casename": self._extract_casename(image_filename), + }) + + self.use_metadata = use_metadata + self.no_data_replace = no_data_replace + self.no_label_replace = no_label_replace + + self.transform = transform if transform else default_transform + + def _extract_basename(self, filepath: str) -> str: + """Extract the base filename without extension.""" + return os.path.splitext(os.path.basename(filepath))[0] + + def _extract_casename(self, filename: str) -> str: + """Extract the casename from the filename.""" + basename = self._extract_basename(filename) + # Remove 'HLS_' or 'BS_' prefix + casename = basename.replace("HLS_", "").replace("BS_", "") + return casename + + def __len__(self) -> int: + return len(self.samples) + + def _get_coords(self, image: DataArray) -> torch.Tensor: + pixel_scale = image.rio.resolution() + width, height = image.rio.width, image.rio.height + + left, bottom, right, top = image.rio.bounds() + tie_point_x, tie_point_y = left, top + + center_col = width / 2 + center_row = height / 2 + + center_lon = tie_point_x + (center_col * pixel_scale[0]) + center_lat = tie_point_y - (center_row * pixel_scale[1]) + + lat_lon = np.asarray([center_lat, center_lon]) + return torch.tensor(lat_lon, dtype=torch.float32) + + def __getitem__(self, index: int) -> dict[str, Any]: + sample = self.samples[index] + image_files = sample["image_files"] + mask_file = sample["mask_file"] + + images = [] + for idx, image_file in enumerate(image_files): + image = self._load_file(Path(image_file), nan_replace=self.no_data_replace) + if idx == 0 and self.use_metadata: + location_coords = self._get_coords(image) + image = image.to_numpy() + image = np.moveaxis(image, 0, -1) + image = image[..., self.band_indices] + images.append(image) + + images = np.stack(images, axis=0) # (T, H, W, C) + + output = { + "image": images.astype(np.float32), + "mask": self._load_file(Path(mask_file), nan_replace=self.no_label_replace).to_numpy()[0] + } + + if self.transform: + output = self.transform(**output) + + output["mask"] = output["mask"].long() + if self.use_metadata: + output["location_coords"] = location_coords + + return output + + def _load_file(self, path: Path, nan_replace: float | int | None = None) -> DataArray: + data = rioxarray.open_rasterio(path, masked=True) + if nan_replace is not None: + data = data.fillna(nan_replace) + return data + + + def plot(self, sample: dict[str, Tensor], 
suptitle: str | None = None) -> Any: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by `__getitem__`. + suptitle: Optional string to use as a suptitle. + + Returns: + A matplotlib Figure with the rendered sample. + """ + num_images = len(self.time_steps) + 2 + if "prediction" in sample: + num_images += 1 + + rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands] + if len(rgb_indices) != 3: + msg = "Dataset doesn't contain some of the RGB bands" + raise ValueError(msg) + + images = sample["image"] # (C, T, H, W) + mask = sample["mask"].numpy() + num_classes = len(np.unique(mask)) + + fig, ax = plt.subplots(1, num_images, figsize=(num_images * 5, 5)) + + for i in range(len(self.time_steps)): + image = images[:, i, :, :] # (C, H, W) + image = np.transpose(image, (1, 2, 0)) # (H, W, C) + rgb_image = image[..., rgb_indices] + rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8) + rgb_image = np.clip(rgb_image, 0, 1) + ax[i].imshow(rgb_image) + ax[i].axis("off") + ax[i].set_title(f"{self.time_steps[i].capitalize()} Image") + + cmap = plt.get_cmap("jet", num_classes) + norm = Normalize(vmin=0, vmax=num_classes - 1) + + mask_ax_index = len(self.time_steps) + ax[mask_ax_index].imshow(mask, cmap=cmap, norm=norm) + ax[mask_ax_index].axis("off") + ax[mask_ax_index].set_title("Ground Truth Mask") + + if "prediction" in sample: + prediction = sample["prediction"].numpy() + pred_ax_index = mask_ax_index + 1 + ax[pred_ax_index].imshow(prediction, cmap=cmap, norm=norm) + ax[pred_ax_index].axis("off") + ax[pred_ax_index].set_title("Predicted Mask") + + legend_ax_index = -1 + class_names = sample.get("class_names", self.class_names) + positions = np.linspace(0, 1, num_classes) if num_classes > 1 else [0.5] + + legend_handles = [ + mpatches.Patch(color=cmap(pos), label=class_names[i]) + for i, pos in enumerate(positions) + ] + ax[legend_ax_index].legend(handles=legend_handles, loc="center") + ax[legend_ax_index].axis("off") + + if suptitle: + plt.suptitle(suptitle) + + plt.tight_layout() + return fig diff --git a/terratorch/datasets/carbonflux.py b/terratorch/datasets/carbonflux.py new file mode 100644 index 00000000..a7185cd6 --- /dev/null +++ b/terratorch/datasets/carbonflux.py @@ -0,0 +1,247 @@ +import os +import re +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +import albumentations as A +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pyproj +import rioxarray +import torch + +from terratorch.datasets.generic_multimodal_dataset import MultimodalToTensor +from terratorch.datasets.transforms import MultimodalTransforms +from terratorch.datasets.utils import default_transform, validate_bands +from torchgeo.datasets import NonGeoDataset + + +class CarbonFluxNonGeo(NonGeoDataset): + """Dataset for Carbon Flux regression from HLS images and MERRA data.""" + + all_band_names = ( + "BLUE", "GREEN", "RED", "NIR", "SWIR_1", "SWIR_2", + ) + + rgb_bands = ( + "RED", "GREEN", "BLUE", + ) + + merra_var_names = ( + "T2MIN", "T2MAX", "T2MEAN", "TSMDEWMEAN", "GWETROOT", + "LHLAND", "SHLAND", "SWLAND", "PARDFLAND", "PRECTOTLAND" + ) + + splits = {"train": "train", "test": "test"} + + BAND_SETS = {"all": all_band_names, "rgb": rgb_bands} + + metadata_file = "data_train_hls_37sites_v0_1.csv" + + def __init__( + self, + data_root: str, + split: str = "train", + bands: Sequence[str] = BAND_SETS["all"], + transform: A.Compose | None = None, + gpp_mean: float 
| None = None,
+        gpp_std: float | None = None,
+        no_data_replace: float | None = 0.0001,
+        use_metadata: bool = False,
+        modalities: Sequence[str] = ("image", "merra_vars")
+    ) -> None:
+        """Initialize the CarbonFluxNonGeo dataset.
+
+        Args:
+            data_root (str): Path to the data root directory.
+            split (str): 'train' or 'test'.
+            bands (Sequence[str]): Bands to use. Defaults to all bands.
+            transform (Optional[A.Compose]): Albumentations transform to be applied.
+            gpp_mean (float): Mean for GPP normalization.
+            gpp_std (float): Standard deviation for GPP normalization.
+            no_data_replace (Optional[float]): Value to replace NO_DATA values in images.
+            use_metadata (bool): Whether to return metadata (coordinates and date).
+            modalities (Sequence[str]): Modalities returned in the sample dict.
+        """
+        super().__init__()
+        if split not in self.splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
+            raise ValueError(msg)
+
+        self.split = split
+
+        validate_bands(bands, self.all_band_names)
+        self.bands = bands
+        self.band_indices = [self.all_band_names.index(band) for band in bands]
+
+        self.data_root = Path(data_root)
+
+        # Load the CSV file with metadata
+        csv_file = self.data_root / self.metadata_file
+        df = pd.read_csv(csv_file)
+
+        # Get list of image filenames in the split directory
+        image_dir = self.data_root / self.split
+        image_files = [f.name for f in image_dir.glob("*.tiff")]
+
+        df["Chip"] = df["Chip"].str.replace(".tif$", ".tiff", regex=True)
+        # Filter the DataFrame to include only rows with 'Chip' in image_files
+        df = df[df["Chip"].isin(image_files)]
+
+        # Build the samples list
+        self.samples = []
+        for _, row in df.iterrows():
+            image_path = image_dir / row["Chip"]
+            # MERRA-2 driver vector and GPP target
+            merra_vars = row[list(self.merra_var_names)].values.astype(np.float32)
+            gpp = row["GPP"]
+            self.samples.append({
+                "image_path": str(image_path),
+                "merra_vars": merra_vars,
+                "gpp": gpp,
+            })
+
+        if gpp_mean is None or gpp_std is None:
+            msg = "Mean and standard deviation for GPP must be provided."
+ raise ValueError(msg) + self.gpp_mean = gpp_mean + self.gpp_std = gpp_std + + self.use_metadata = use_metadata + self.modalities = modalities + self.no_data_replace = no_data_replace + + if transform is None: + self.transform = MultimodalToTensor(self.modalities) + else: + transform = {m: transform[m] if m in transform else default_transform + for m in self.modalities} + self.transform = MultimodalTransforms(transform, shared=False) + + def __len__(self) -> int: + return len(self.samples) + + def _load_file(self, path: str, nan_replace: float | int | None = None): + data = rioxarray.open_rasterio(path, masked=True) + if nan_replace is not None: + data = data.fillna(nan_replace) + return data + + def _get_coords(self, image) -> torch.Tensor: + """Extract the center coordinates from the image geospatial metadata.""" + pixel_scale = image.rio.resolution() + width, height = image.rio.width, image.rio.height + + left, bottom, right, top = image.rio.bounds() + tie_point_x, tie_point_y = left, top + + center_col = width / 2 + center_row = height / 2 + + center_lon = tie_point_x + (center_col * pixel_scale[0]) + center_lat = tie_point_y - (center_row * pixel_scale[1]) + + src_crs = image.rio.crs + dst_crs = "EPSG:4326" + + transformer = pyproj.Transformer.from_crs(src_crs, dst_crs, always_xy=True) + lon, lat = transformer.transform(center_lon, center_lat) + + coords = np.array([lat, lon], dtype=np.float32) + return torch.from_numpy(coords) + + def _get_date(self, filename: str) -> torch.Tensor: + """Extract the date from the filename.""" + base_filename = os.path.basename(filename) + pattern = r"HLS\..{3}\.[A-Z0-9]{6}\.(?P\d{7}T\d{6})\..*\.tiff$" + match = re.match(pattern, base_filename) + if not match: + msg = f"Filename {filename} does not match expected pattern." + raise ValueError(msg) + + date_str = match.group("date") + year = int(date_str[:4]) + julian_day = int(date_str[4:7]) + + date_tensor = torch.tensor([year, julian_day], dtype=torch.int32) + return date_tensor + + def __getitem__(self, idx: int) -> dict[str, Any]: + sample = self.samples[idx] + image_path = sample["image_path"] + + image = self._load_file(image_path, nan_replace=self.no_data_replace) + + if self.use_metadata: + location_coords = self._get_coords(image) + temporal_coords = self._get_date(os.path.basename(image_path)) + + image = image.to_numpy() # (C, H, W) + image = image[self.band_indices, ...] + image = np.moveaxis(image, 0, -1) # (H, W, C) + + merra_vars = np.array(sample["merra_vars"]) + target = np.array(sample["gpp"]) + target_norm = (target - self.gpp_mean) / self.gpp_std + target_norm = torch.tensor(target_norm, dtype=torch.float32) + output = { + "image": image.astype(np.float32), + "merra_vars": merra_vars, + } + + if self.transform: + output = self.transform(output) + + output = { + "image": {m: output[m] for m in self.modalities if m in output}, + "mask": target_norm + } + if self.use_metadata: + output["location_coords"] = location_coords + output["temporal_coords"] = temporal_coords + + return output + + def plot(self, sample: dict[str, Any], suptitle: str | None = None) -> Any: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by `__getitem__`. + suptitle: Optional title for the figure. + + Returns: + A matplotlib figure with the rendered sample. 
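+
+        Note: only the HLS image is rendered; the per-sample MERRA-2 driver
+        vector is not plotted. ``sample["image"]`` is expected to be a
+        (C, H, W) tensor here.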
+        """
+        image = sample["image"]
+        if isinstance(image, dict):
+            # __getitem__ nests the modalities under "image"; pick the raster modality
+            image = image["image"]
+        image = image.numpy()
+
+        image = np.transpose(image, (1, 2, 0))  # (H, W, C)
+
+        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
+        if len(rgb_indices) != 3:
+            msg = "Dataset doesn't contain some of the RGB bands"
+            raise ValueError(msg)
+
+        rgb_image = image[..., rgb_indices]
+
+        rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
+        rgb_image = np.clip(rgb_image, 0, 1)
+
+        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
+        ax.imshow(rgb_image)
+        ax.axis("off")
+        ax.set_title("Image")
+
+        if suptitle:
+            plt.suptitle(suptitle)
+
+        plt.tight_layout()
+        return fig
diff --git a/terratorch/datasets/forestnet.py b/terratorch/datasets/forestnet.py
new file mode 100644
index 00000000..86dd37be
--- /dev/null
+++ b/terratorch/datasets/forestnet.py
@@ -0,0 +1,227 @@
+import datetime
+import glob
+import json
+import os
+import re
+from collections import defaultdict
+from collections.abc import Sequence
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib.patches import Rectangle
+from PIL import Image
+from sklearn.model_selection import StratifiedShuffleSplit
+import albumentations as A
+
+from terratorch.datasets.utils import default_transform, validate_bands
+from torchgeo.datasets import NonGeoDataset
+
+
+class ForestNetNonGeo(NonGeoDataset):
+    """NonGeo dataset implementation for ForestNet."""
+
+    all_band_names = (
+        "RED", "GREEN", "BLUE", "NIR", "SWIR_1", "SWIR_2"
+    )
+
+    rgb_bands = (
+        "RED", "GREEN", "BLUE",
+    )
+
+    splits = ("train", "test", "val")
+
+    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
+
+    default_label_map = {  # noqa: RUF012
+        "Plantation": 0,
+        "Smallholder agriculture": 1,
+        "Grassland shrubland": 2,
+        "Other": 3,
+    }
+
+    def __init__(
+        self,
+        data_root: str,
+        split: str = "train",
+        label_map: dict[str, int] = default_label_map,
+        transform: A.Compose | None = None,
+        fraction: float = 1.0,
+        bands: Sequence[str] = BAND_SETS["all"],
+        use_metadata: bool = False,
+    ) -> None:
+        """Initialize the ForestNetNonGeo dataset.
+
+        Args:
+            data_root (str): Path to the data root directory.
+            split (str): One of 'train', 'val', or 'test'.
+            label_map (dict[str, int]): Mapping from label names to integer labels.
+            transform: Transformations to be applied to the images.
+            fraction (float): Fraction of the dataset to use. Defaults to 1.0 (use all data).
+            bands (Sequence[str]): Bands to use. Defaults to all bands.
+            use_metadata (bool): Whether to return metadata (coordinates and dates).
+        """
+        super().__init__()
+        if split not in self.splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(self.splits)}."
+            raise ValueError(msg)
+        self.split = split
+
+        validate_bands(bands, self.all_band_names)
+        self.bands = bands
+        self.band_indices = [self.all_band_names.index(b) for b in bands]
+
+        self.use_metadata = use_metadata
+
+        self.data_root = Path(data_root)
+        self.label_map = label_map
+
+        # Load the CSV file corresponding to the split
+        csv_file = self.data_root / f"{split}_filtered.csv"
+        original_df = pd.read_csv(csv_file)
+
+        # Apply stratified sampling if fraction < 1.0
+        if fraction < 1.0:
+            sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - fraction, random_state=47)
+            stratified_indices, _ = next(sss.split(original_df, original_df["merged_label"]))
+            self.dataset = original_df.iloc[stratified_indices].reset_index(drop=True)
+        else:
+            self.dataset = original_df
+
+        self.transform = transform if transform else default_transform
+
+    def __len__(self) -> int:
+        return len(self.dataset)
+
+    def _get_coords(self, event_path: Path) -> torch.Tensor:
+        auxiliary_path = event_path / "auxiliary"
+        osm_json_path = auxiliary_path / "osm.json"
+
+        with open(osm_json_path) as f:
+            osm_data = json.load(f)
+            lat = float(osm_data["closest_city"]["lat"])
+            lon = float(osm_data["closest_city"]["lon"])
+            lat_lon = np.asarray([lat, lon])
+
+        return torch.tensor(lat_lon, dtype=torch.float32)
+
+    def _get_dates(self, image_files: list) -> torch.Tensor:
+        dates = []
+        pattern = re.compile(r"(\d{4})_(\d{2})_(\d{2})_cloud_\d+\.(png|npy)")
+        for img_path in image_files:
+            match = pattern.search(img_path)
+            year, month, day = int(match.group(1)), int(match.group(2)), int(match.group(3))
+            date_obj = datetime.datetime(year, month, day)  # noqa: DTZ001
+            julian_day = date_obj.timetuple().tm_yday
+            date_tensor = torch.tensor([year, julian_day], dtype=torch.int32)
+            dates.append(date_tensor)
+        return torch.stack(dates, dim=0)
+
+    def __getitem__(self, index: int):
+        path = self.data_root / self.dataset["example_path"][index]
+        label = self.map_label(index)
+
+        visible_images, infrared_images, temporal_coords = self._load_images(path)
+
+        visible_images = np.stack(visible_images, axis=0)
+        infrared_images = np.stack(infrared_images, axis=0)
+        merged_images = np.concatenate([visible_images, infrared_images], axis=-1)  # (T, H, W, 2C)
+        merged_images = merged_images[..., self.band_indices]  # (T, H, W, len(bands))
+        output = {
+            "image": merged_images.astype(np.float32)
+        }
+
+        if self.transform:
+            output = self.transform(**output)
+
+        if self.use_metadata:
+            location_coords = self._get_coords(path)
+            output["location_coords"] = location_coords
+            output["temporal_coords"] = temporal_coords
+
+        output["label"] = label
+
+        return output
+
+    def _load_images(self, path: str):
+        """Load visible and infrared images from the given event path."""
+        visible_image_files = glob.glob(os.path.join(path, "images/visible/*_cloud_*.png"))
+        infra_image_files = glob.glob(os.path.join(path, "images/infrared/*_cloud_*.npy"))
+
+        selected_visible_images = self.select_images(visible_image_files)
+        selected_infra_images = self.select_images(infra_image_files)
+
+        dates = None
+        if self.use_metadata:
+            dates = self._get_dates(selected_visible_images)
+
+        vis_images = [np.array(Image.open(img)) for img in selected_visible_images]  # list of (H, W, C)
+        inf_images = [np.load(img, allow_pickle=True) for img in selected_infra_images]  # list of (H, W, C)
+        return vis_images, inf_images, dates
+
+    def least_cloudy_image(self, image_files):
+        pattern = re.compile(r"(\d{4})_\d{2}_\d{2}_cloud_(\d+)\.(png|npy)")
+        lowest_cloud_images = defaultdict(lambda: {"path": None, "cloud_value":
float("inf")})
+
+        for path in image_files:
+            match = pattern.search(path)
+            if match:
+                year, cloud_value = match.group(1), int(match.group(2))
+                if cloud_value < lowest_cloud_images[year]["cloud_value"]:
+                    lowest_cloud_images[year] = {"path": path, "cloud_value": cloud_value}
+
+        return [info["path"] for info in lowest_cloud_images.values()]
+
+    def match_timesteps(self, image_files, selected_images):
+        if len(selected_images) < 3:
+            extra_imgs = [img for img in image_files if img not in selected_images]
+            selected_images += extra_imgs[:3 - len(selected_images)]
+
+        while len(selected_images) < 3:
+            selected_images.append(selected_images[-1])
+        return selected_images[:3]
+
+    def select_images(self, image_files):
+        selected = self.least_cloudy_image(image_files)
+        return self.match_timesteps(image_files, selected)
+
+    def map_label(self, index: int) -> torch.Tensor:
+        """Map the label name to an integer label."""
+        label_name = self.dataset["merged_label"][index]
+        label = self.label_map[label_name]
+        return label
+
+    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None):
+        num_images = sample["image"].shape[1] + 1
+
+        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
+        if len(rgb_indices) != 3:
+            msg = "Dataset doesn't contain some of the RGB bands"
+            raise ValueError(msg)
+
+        fig, ax = plt.subplots(1, num_images, figsize=(15, 5))
+
+        for i in range(sample["image"].shape[1]):
+            image = sample["image"][:, i, :, :]
+            if torch.is_tensor(image):
+                image = image.permute(1, 2, 0).numpy()
+            rgb_image = image[..., rgb_indices]
+            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min() + 1e-8)
+            rgb_image = np.clip(rgb_image, 0, 1)
+            ax[i].imshow(rgb_image)
+            ax[i].axis("off")
+            ax[i].set_title(f"Timestep {i + 1}")
+
+        # label_map maps names to integers, so invert it to recover the class name
+        inverse_label_map = {v: k for k, v in self.label_map.items()}
+        legend_handles = [Rectangle((0, 0), 1, 1, color="blue")]
+        legend_label = [inverse_label_map.get(int(sample["label"]), "Unknown Label")]
+        ax[-1].legend(legend_handles, legend_label, loc="center")
+        ax[-1].axis("off")
+
+        if suptitle:
+            plt.suptitle(suptitle)
+
+        plt.tight_layout()
+        return fig
diff --git a/terratorch/datasets/generic_multimodal_dataset.py b/terratorch/datasets/generic_multimodal_dataset.py
new file mode 100644
index 00000000..5b897c8b
--- /dev/null
+++ b/terratorch/datasets/generic_multimodal_dataset.py
@@ -0,0 +1,23 @@
+import numpy as np
+import torch
+
+
+class MultimodalToTensor:
+    def __init__(self, modalities):
+        self.modalities = modalities
+    def __call__(self, d):
+        new_dict = {}
+        for k, v in d.items():
+            if not isinstance(v, np.ndarray):
+                new_dict[k] = v
+            else:
+                # TODO: This code has hard assumptions on the data structure
+                if k in self.modalities and len(v.shape) >= 3:  # raster modalities with 3+ dimensions
+                    if len(v.shape) <= 4:
+                        v = np.moveaxis(v, -1, 0)  # (C, H, W) or (C, T, H, W)
+                    elif len(v.shape) == 5:
+                        v = np.moveaxis(v, -1, 1)  # (B, C, T, H, W)
+                    else:
+                        raise ValueError(f"Unexpected shape for {k}: {v.shape}")
+                new_dict[k] = torch.from_numpy(v)
+        return new_dict
diff --git a/terratorch/datasets/landslide4sense.py b/terratorch/datasets/landslide4sense.py
new file mode 100644
index 00000000..54b71e06
--- /dev/null
+++ b/terratorch/datasets/landslide4sense.py
@@ -0,0 +1,155 @@
+from collections.abc import Sequence
+from pathlib import Path
+
+import albumentations as A
+import h5py
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib import colormaps
+from matplotlib.colors import Normalize
+
+from terratorch.datasets.utils import default_transform, validate_bands
+from torchgeo.datasets import NonGeoDataset
+
+
+class Landslide4SenseNonGeo(NonGeoDataset):
+    """NonGeo dataset implementation for Landslide4Sense."""
+
+    all_band_names = (
+        "COASTAL AEROSOL",
+        "BLUE",
+        "GREEN",
+        "RED",
+        "RED_EDGE_1",
+        "RED_EDGE_2",
+        "RED_EDGE_3",
+        "NIR_BROAD",
+        "WATER_VAPOR",
+        "CIRRUS",
+        "SWIR_1",
+        "SWIR_2",
+        "SLOPE",
+        "DEM",
+    )
+
+    rgb_bands = ("RED", "GREEN", "BLUE")
+    BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
+
+    splits = {"train": "train", "val": "validation", "test": "test"}
+
+
+    def __init__(
+        self,
+        data_root: str,
+        split: str = "train",
+        bands: Sequence[str] = BAND_SETS["all"],
+        transform: A.Compose | None = None,
+    ) -> None:
+        """Initialize the Landslide4Sense dataset.
+
+        Args:
+            data_root (str): Path to the data root directory.
+            split (str): One of 'train', 'val', or 'test'.
+            bands (Sequence[str]): Bands to be used. Defaults to all bands.
+            transform (A.Compose | None): Albumentations transform to be applied.
+                Defaults to None, which applies default_transform().
+        """
+        super().__init__()
+
+        if split not in self.splits:
+            msg = f"Incorrect split '{split}', please choose one of {list(self.splits.keys())}."
+            raise ValueError(msg)
+        split_name = self.splits[split]
+        self.split = split
+
+        validate_bands(bands, self.all_band_names)
+        self.bands = bands
+        self.band_indices = [self.all_band_names.index(b) for b in bands]
+
+        self.data_directory = Path(data_root)
+
+        images_dir = self.data_directory / "images" / split_name
+        annotations_dir = self.data_directory / "annotations" / split_name
+
+        self.image_files = sorted(images_dir.glob("image_*.h5"))
+        self.mask_files = sorted(annotations_dir.glob("mask_*.h5"))
+
+        self.transform = transform if transform else default_transform
+
+    def __len__(self) -> int:
+        return len(self.image_files)
+
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
+        image_file = self.image_files[index]
+        mask_file = self.mask_files[index]
+
+        with h5py.File(image_file, "r") as h5file:
+            image = np.array(h5file["img"])[..., self.band_indices]
+
+        with h5py.File(mask_file, "r") as h5file:
+            mask = np.array(h5file["mask"])
+
+        output = {"image": image.astype(np.float32), "mask": mask}
+
+        if self.transform:
+            output = self.transform(**output)
+        output["mask"] = output["mask"].long()
+
+        return output
+
+    def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) -> plt.Figure:
+        rgb_indices = [self.bands.index(band) for band in self.rgb_bands if band in self.bands]
+
+        if len(rgb_indices) != 3:
+            msg = "Dataset doesn't contain some of the RGB bands"
+            raise ValueError(msg)
+
+        image = sample["image"]
+        mask = sample["mask"].numpy()
+        if torch.is_tensor(image):
+            image = image.permute(1, 2, 0).numpy()
+
+        rgb_image = image[:, :, rgb_indices]
+
+        rgb_image = (rgb_image - rgb_image.min(axis=(0, 1))) / (rgb_image.max(axis=(0, 1)) - rgb_image.min(axis=(0, 1)) + 1e-8)
+        rgb_image = np.clip(rgb_image, 0, 1)
+
+        num_classes = len(np.unique(mask))
+        cmap = colormaps["jet"]
+        norm = Normalize(vmin=0, vmax=num_classes - 1)
+
+        num_images = 4 if "prediction" in sample else 3
+        fig, ax = plt.subplots(1, num_images, figsize=(num_images * 4, 4), tight_layout=True)
+
+        ax[0].imshow(rgb_image)
+        ax[0].set_title("Image")
+        ax[0].axis("off")
+
+        ax[1].imshow(mask, cmap=cmap, norm=norm)
+        ax[1].set_title("Ground Truth Mask")
+        ax[1].axis("off")
+
+        ax[2].imshow(rgb_image)
+        ax[2].imshow(mask, cmap=cmap, alpha=0.3, norm=norm)
+        ax[2].set_title("GT
Mask on Image")
+        ax[2].axis("off")
+
+        if "prediction" in sample:
+            prediction = sample["prediction"].numpy()
+            ax[3].imshow(prediction, cmap=cmap, norm=norm)
+            ax[3].set_title("Predicted Mask")
+            ax[3].axis("off")
+
+        if sample.get("class_names"):
+            class_names = sample["class_names"]
+            legend_handles = [
+                mpatches.Patch(color=cmap(i), label=class_names[i]) for i in range(num_classes)
+            ]
+            ax[0].legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc="upper left")
+
+        if suptitle:
+            plt.suptitle(suptitle)
+
+        return fig
diff --git a/terratorch/datasets/m_forestnet.py b/terratorch/datasets/m_forestnet.py
index e65f37b7..cbdba016 100644
--- a/terratorch/datasets/m_forestnet.py
+++ b/terratorch/datasets/m_forestnet.py
@@ -11,13 +11,13 @@
 import pandas as pd
 import torch
 from albumentations.pytorch import ToTensorV2
-from torchgeo.datasets import NonGeoDataset
 
 from terratorch.datasets.utils import (
     clip_image,
     default_transform,
     validate_bands,
 )
+from torchgeo.datasets import NonGeoDataset
 
 
 class MForestNetNonGeo(NonGeoDataset):
diff --git a/terratorch/datasets/transforms.py b/terratorch/datasets/transforms.py
index c3b608b8..a43f6436 100644
--- a/terratorch/datasets/transforms.py
+++ b/terratorch/datasets/transforms.py
@@ -1,8 +1,10 @@
 # Copyright contributors to the Terratorch project
 
+import albumentations as A
+import numpy as np
+import torch
 from albumentations import BasicTransform, Compose, ImageOnlyTransform
 from einops import rearrange
-from torch import Tensor
 
 N_DIMS_FOR_TEMPORAL = 4
 N_DIMS_FLATTENED_TEMPORAL = 3
@@ -18,6 +20,11 @@ def fn(data):
     return fn
 
 
+def default_non_image_transform(array):
+    if np.issubdtype(array.dtype, np.floating) or np.issubdtype(array.dtype, np.integer):
+        return torch.from_numpy(array)
+    else:
+        return array
 
 class FlattenTemporalIntoChannels(ImageOnlyTransform):
     """Flatten the temporal dimension into channels"""
@@ -91,3 +98,34 @@ def apply(self, img, **params):
 
     def get_transform_init_args_names(self):
         return "band_indices"
+
+
+class MultimodalTransforms:
+    """Applies albumentations transforms to multiple modalities."""
+    def __init__(
+        self,
+        transforms: dict | A.Compose,
+        shared: bool = True,
+        non_image_modalities: list[str] | None = None,
+        non_image_transform: object | None = None,
+    ):
+        self.transforms = transforms
+        self.shared = shared
+        self.non_image_modalities = non_image_modalities or []
+        self.non_image_transform = non_image_transform or default_non_image_transform
+
+    def __call__(self, data: dict):
+        if self.shared:
+            # albumentations requires a key 'image' and treats all other keys as additional targets
+            image_modality = list(set(data.keys()) - set(self.non_image_modalities))[0]
+            data["image"] = data.pop(image_modality)
+            data = self.transforms(**data)
+            data[image_modality] = data.pop("image")
+
+            # Process sequence data, which albumentations ignores, as 'global_label'
+            for modality in self.non_image_modalities:
+                data[modality] = self.non_image_transform(data[modality])
+        else:
+            # Apply the transforms of each modality separately (image modalities only)
+            for key, value in data.items():
+                data[key] = self.transforms[key](image=value)["image"]
+        return data
diff --git a/terratorch/models/pixel_wise_model.py b/terratorch/models/pixel_wise_model.py
index e82941e1..dc960926 100644
--- a/terratorch/models/pixel_wise_model.py
+++ b/terratorch/models/pixel_wise_model.py
@@ -90,7 +90,11 @@ def _check_for_single_channel_and_squeeze(x):
 
     def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
         """Sequentially pass `x` through model`s encoder, decoder and heads"""
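+        # `x` is a plain tensor for unimodal models; multimodal models instead pass
+        # a dict of modality tensors, which the shape handling below accounts for.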
         self.check_input_shape(x)
-        input_size = x.shape[-2:]
+        if isinstance(x, torch.Tensor):
+            input_size = x.shape[-2:]
+        elif isinstance(x, dict):
+            # Multimodal input is passed as a dict of modality tensors
+            input_size = list(x.values())[0].shape[-2:]
         features = self.encoder(x, **kwargs)
 
         ## only for backwards compatibility with pre-neck times.
diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py
index c1ef6a69..d529c80f 100644
--- a/terratorch/tasks/classification_tasks.py
+++ b/terratorch/tasks/classification_tasks.py
@@ -139,8 +139,14 @@ def configure_losses(self) -> None:
             ValueError: If *loss* is invalid.
         """
         loss: str = self.hparams["loss"]
+        ignore_index = self.hparams["ignore_index"]
+
+        class_weights = (
+            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
+        )
         if loss == "ce":
-            self.criterion: nn.Module = nn.CrossEntropyLoss(weight=self.hparams["class_weights"])
+            ignore_value = -100 if ignore_index is None else ignore_index
+            self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
         elif loss == "bce":
             self.criterion = nn.BCEWithLogitsLoss()
         elif loss == "jaccard":
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py
index ec5abed0..2ccc9890 100644
--- a/terratorch/tasks/regression_tasks.py
+++ b/terratorch/tasks/regression_tasks.py
@@ -315,7 +315,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         rest = {k:batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
-        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
+        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
         y_hat = model_output.output
         out = y_hat[y != -1]
         mask = y[y != -1]
@@ -326,6 +326,9 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         try:
             datamodule = self.trainer.datamodule
             batch["prediction"] = y_hat
+            if isinstance(batch["image"], dict):
+                # Multimodal input
+                batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality]
             for key in ["image", "mask", "prediction"]:
                 batch[key] = batch[key].cpu()
             sample = unbind_samples(batch)[0]
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index fd96d7fa..df1638d8 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -265,7 +265,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         rest = {k:batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
-        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
+        self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
         y_hat_hard = to_segmentation_prediction(model_output)
         self.val_metrics.update(y_hat_hard, y)
 
@@ -273,6 +273,9 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         try:
             datamodule = self.trainer.datamodule
             batch["prediction"] = y_hat_hard
+            if isinstance(batch["image"], dict):
+                # Multimodal input
+                batch["image"] = batch["image"][self.trainer.datamodule.rgb_modality]
             for key in ["image", "mask", "prediction"]:
                 batch[key] = batch[key].cpu()
             sample = unbind_samples(batch)[0]
From 564866f6928420c0beeca480ce62314ca88a870a Mon Sep 17 00:00:00 2001
From:
Pedro Henrique Conrado
Date: Mon, 2 Dec 2024 16:47:59 -0500
Subject: [PATCH 2/2] temporary changes

---
 terratorch/datamodules/__init__.py                 |  3 ++
 terratorch/datasets/__init__.py                    |  2 +-
 terratorch/datasets/fire_scars.py                  |  2 +-
 terratorch/datasets/landslide4sense.py             |  2 +-
 .../multi_temporal_crop_classification.py          | 41 +++++++++----------
 terratorch/datasets/sen1floods11.py                |  2 +-
 6 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/terratorch/datamodules/__init__.py b/terratorch/datamodules/__init__.py
index 1c2af9b0..b75da89b 100644
--- a/terratorch/datamodules/__init__.py
+++ b/terratorch/datamodules/__init__.py
@@ -39,6 +39,9 @@
 from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule
 from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule
 
+# miscellaneous datamodules
+from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule
+
 # Generic classification datamodule
 from terratorch.datamodules.sen4map import Sen4MapLucasDataModule
 
diff --git a/terratorch/datasets/__init__.py b/terratorch/datasets/__init__.py
index 9f7f3bb1..41cd2fa7 100644
--- a/terratorch/datasets/__init__.py
+++ b/terratorch/datasets/__init__.py
@@ -50,7 +50,7 @@
     "GenericNonGeoSegmentationDataset",
     "GenericNonGeoPixelwiseRegressionDataset",
     "GenericNonGeoClassificationDataset",
-    "GenericNonGeoRegressionDataset",
+    #"GenericNonGeoRegressionDataset",
     "BurnIntensityNonGeo",
     "CarbonFluxNonGeo",
     "Landslide4SenseNonGeo",
diff --git a/terratorch/datasets/fire_scars.py b/terratorch/datasets/fire_scars.py
index 7ddad452..f5b65516 100644
--- a/terratorch/datasets/fire_scars.py
+++ b/terratorch/datasets/fire_scars.py
@@ -198,7 +198,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
         ax[3].imshow(image)
         ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm)
 
-        if prediction:
+        if "prediction" in sample:
             ax[4].title.set_text("Predicted Mask")
             ax[4].imshow(prediction, cmap="jet", norm=norm)
 
diff --git a/terratorch/datasets/landslide4sense.py b/terratorch/datasets/landslide4sense.py
index 54b71e06..5c949ded 100644
--- a/terratorch/datasets/landslide4sense.py
+++ b/terratorch/datasets/landslide4sense.py
@@ -137,7 +137,7 @@ def plot(self, sample: dict[str, torch.Tensor], suptitle: str | None = None) ->
         ax[2].axis("off")
 
         if "prediction" in sample:
-            prediction = sample["prediction"].numpy()
+            prediction = sample["prediction"]
             ax[3].imshow(prediction, cmap=cmap, norm=norm)
             ax[3].set_title("Predicted Mask")
             ax[3].axis("off")
diff --git a/terratorch/datasets/multi_temporal_crop_classification.py b/terratorch/datasets/multi_temporal_crop_classification.py
index 32e5421f..709800d4 100644
--- a/terratorch/datasets/multi_temporal_crop_classification.py
+++ b/terratorch/datasets/multi_temporal_crop_classification.py
@@ -235,37 +235,35 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure
             raise ValueError(msg)
 
         images = sample["image"]
-        if not self.expand_temporal_dimension:
-            images = rearrange(images, "(channels time) h w -> channels time h w", channels=len(self.bands))
+        images = images[rgb_indices, ...]
# Shape: (T, 3, H, W) - # RGB -> channels-last - images = images[rgb_indices, ...].permute(1, 2, 3, 0).numpy() - mask = sample["mask"].numpy() - - images = [clip_image(img) for img in images] + processed_images = [] + for t in range(self.time_steps): + img = images[t] + img = img.permute(1, 2, 0) + img = img.numpy() + img = clip_image(img) + processed_images.append(img) + mask = sample["mask"].numpy() if "prediction" in sample: - prediction = sample["prediction"] num_images += 1 - else: - prediction = None - fig, ax = plt.subplots(1, num_images, figsize=(12, 5), layout="compressed") - ax[0].axis("off") norm = mpl.colors.Normalize(vmin=0, vmax=self.num_classes - 1) + for i, img in enumerate(processed_images): + ax[i + 1].axis("off") + ax[i + 1].title.set_text(f"T{i}") + ax[i + 1].imshow(img) - for i, img in enumerate(images): - ax[i+1].axis("off") - ax[i+1].title.set_text(f"T{i}") - ax[i+1].imshow(img) - - ax[self.time_steps+1].axis("off") - ax[self.time_steps+1].title.set_text("Ground Truth Mask") - ax[self.time_steps+1].imshow(mask, cmap="jet", norm=norm) + ax[self.time_steps + 1].axis("off") + ax[self.time_steps + 1].title.set_text("Ground Truth Mask") + ax[self.time_steps + 1].imshow(mask, cmap="jet", norm=norm) - if prediction: + if "prediction" in sample: + prediction = sample["prediction"] + ax[self.time_steps + 1].axis("off") ax[self.time_steps+2].title.set_text("Predicted Mask") ax[self.time_steps+2].imshow(prediction, cmap="jet", norm=norm) @@ -274,6 +272,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure handles = [Rectangle((0, 0), 1, 1, color=tuple(v for v in c)) for k, c, n in legend_data] labels = [n for k, c, n in legend_data] ax[0].legend(handles, labels, loc="center") + if suptitle is not None: plt.suptitle(suptitle) diff --git a/terratorch/datasets/sen1floods11.py b/terratorch/datasets/sen1floods11.py index e6fe9362..b36965c7 100644 --- a/terratorch/datasets/sen1floods11.py +++ b/terratorch/datasets/sen1floods11.py @@ -228,7 +228,7 @@ def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure ax[3].imshow(image) ax[3].imshow(mask, cmap="jet", alpha=0.3, norm=norm) - if prediction: + if "prediction" in sample: ax[4].title.set_text("Predicted Mask") ax[4].imshow(prediction, cmap="jet", norm=norm)
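
For reference, a minimal usage sketch of one of the datasets added in this series (illustrative
only: the data_root path below is an assumption about where the Landslide4Sense archive was
extracted, and the shape comments reflect the default transform's expected output):

    from terratorch.datasets import Landslide4SenseNonGeo

    # Hypothetical local extract location of the Landslide4Sense data
    ds = Landslide4SenseNonGeo(data_root="./data/landslide4sense", split="train")

    sample = ds[0]  # dict with "image" (float tensor, C x H x W) and "mask" (long tensor, H x W)
    fig = ds.plot(sample, suptitle="Landslide4Sense sample")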