Skip to content

Commit

Permalink
Merge pull request #264 from PedroConrado/add/new-datasets
Browse files Browse the repository at this point in the history
[WIP] Adds new datasets
  • Loading branch information
romeokienzler authored Dec 4, 2024
2 parents 55acdb9 + 745e973 commit 9754f5c
Show file tree
Hide file tree
Showing 22 changed files with 2,073 additions and 37 deletions.
10 changes: 10 additions & 0 deletions terratorch/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule
from terratorch.datamodules.generic_multimodal_data_module import GenericMultiModalDataModule

from terratorch.datamodules.burn_intensity import BurnIntensityNonGeoDataModule
from terratorch.datamodules.carbonflux import CarbonFluxNonGeoDataModule
from terratorch.datamodules.landslide4sense import Landslide4SenseNonGeoDataModule
from terratorch.datamodules.biomassters import BioMasstersNonGeoDataModule
from terratorch.datamodules.forestnet import ForestNetNonGeoDataModule

# miscellaneous datamodules
from terratorch.datamodules.openearthmap import OpenEarthMapNonGeoDataModule
Expand All @@ -54,6 +59,11 @@
"GenericNonGeoSegmentationDataModule",
"GenericNonGeoClassificationDataModule",
# "GenericNonGeoRegressionDataModule",
"BurnIntensityNonGeoDataModule",
"CarbonFluxNonGeoDataModule",
"Landslide4SenseNonGeoDataModule",
"ForestNetNonGeoDataModule",
"BioMasstersNonGeoDataModule",
"Sen1Floods11NonGeoDataModule",
"Sen4MapLucasDataModule",
"FireScarsNonGeoDataModule",
Expand Down
190 changes: 190 additions & 0 deletions terratorch/datamodules/biomassters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from collections.abc import Sequence
from typing import Any

import albumentations as A
from torch.utils.data import DataLoader

from terratorch.datamodules.generic_multimodal_data_module import MultimodalNormalize, wrap_in_compose_is_list
from terratorch.datamodules.generic_pixel_wise_data_module import Normalize
from terratorch.datasets import BioMasstersNonGeo
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.transforms import AugmentationSequential

# Per-band normalization statistics for the BioMassters dataset.
# "AGBM" is the regression target (above-ground biomass); "S1" and "S2"
# map Sentinel-1/Sentinel-2 band names to channel-wise statistics.
# NOTE(review): values look precomputed over the training split — confirm.
MEANS = {
    "AGBM": 63.4584,
    "S1": {
        "VV_Asc": 0.08871397,
        "VH_Asc": 0.02172604,
        "VV_Desc": 0.08556002,
        "VH_Desc": 0.02795591,
        "RVI_Asc": 0.75507677,
        "RVI_Desc": 0.6600374
    },
    "S2": {
        "BLUE": 1633.0802,
        "GREEN": 1610.0035,
        "RED": 1599.557,
        "RED_EDGE_1": 1916.7083,
        "RED_EDGE_2": 2478.8325,
        "RED_EDGE_3": 2591.326,
        "NIR_BROAD": 2738.5837,
        "NIR_NARROW": 2685.8281,
        "SWIR_1": 1023.90204,
        "SWIR_2": 696.48755,
        "CLOUD_PROBABILITY": 21.177078
    }
}

# Standard deviations matching the structure of MEANS above.
STDS = {
    "AGBM": 72.21242,
    "S1": {
        "VV_Asc": 0.16714208,
        "VH_Asc": 0.04876742,
        "VV_Desc": 0.19260046,
        "VH_Desc": 0.10272296,
        "RVI_Asc": 0.24945821,
        "RVI_Desc": 0.3590119
    },
    "S2": {
        "BLUE": 2499.7146,
        "GREEN": 2308.5298,
        "RED": 2388.2268,
        "RED_EDGE_1": 2389.6375,
        "RED_EDGE_2": 2209.6467,
        "RED_EDGE_3": 2104.572,
        "NIR_BROAD": 2194.209,
        "NIR_NARROW": 2031.7762,
        "SWIR_1": 934.0556,
        "SWIR_2": 759.8444,
        "CLOUD_PROBABILITY": 49.352486
    }
}

class BioMasstersNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for BioMassters.

    Wraps :class:`BioMasstersNonGeo` for train/val/test stages, computing
    per-sensor channel normalization from the module-level MEANS/STDS tables.
    Note the dataset ships no labeled val split: the "test" split is used for
    both validation and testing.
    """

    # NOTE: the double ".csv.csv" matches the metadata file name as
    # distributed with the dataset — do not "fix" it.
    default_metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: dict[str, Sequence[str]] | Sequence[str] = BioMasstersNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        aug: AugmentationSequential = None,
        drop_last: bool = True,
        sensors: Sequence[str] = ("S1", "S2"),
        as_time_series: bool = False,
        metadata_filename: str = default_metadata_filename,
        max_cloud_percentage: float | None = None,
        max_red_mean: float | None = None,
        include_corrupt: bool = True,
        subset: float = 1,
        seed: int = 42,
        use_four_frames: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize the datamodule.

        Args:
            data_root: Root directory of the dataset.
            batch_size: Samples per batch.
            num_workers: DataLoader worker processes.
            bands: Either a mapping of sensor name -> band names, or a flat
                sequence applied to the first entry of ``sensors``.
            train_transform: Albumentations transform(s) for training.
            val_transform: Albumentations transform(s) for validation.
            test_transform: Albumentations transform(s) for testing.
            aug: Optional augmentation applied on batches; defaults to
                per-sensor normalization built from MEANS/STDS.
            drop_last: Drop the last incomplete training batch.
            sensors: Sensors to load ("S1" and/or "S2").
            as_time_series: Forwarded to the dataset.
            metadata_filename: Metadata CSV file name.
            max_cloud_percentage: Optional S2 cloud filter threshold.
            max_red_mean: Optional red-band filter threshold.
            include_corrupt: Whether to keep known-corrupt samples.
            subset: Fraction of the dataset to use.
            seed: RNG seed for subsetting.
            use_four_frames: Forwarded to the dataset.
            **kwargs: Passed through to NonGeoDataModule.
        """
        super().__init__(BioMasstersNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.sensors = sensors
        if isinstance(bands, dict):
            self.bands = bands
        else:
            # A flat band list applies to the first (single) sensor.
            sens = sensors[0]
            self.bands = {sens: bands}

        # Channel-wise statistics, ordered to match the selected bands.
        self.means = {}
        self.stds = {}
        for sensor in self.sensors:
            self.means[sensor] = [MEANS[sensor][band] for band in self.bands[sensor]]
            self.stds[sensor] = [STDS[sensor][band] for band in self.bands[sensor]]

        self.mask_mean = MEANS["AGBM"]
        self.mask_std = STDS["AGBM"]
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        if len(sensors) == 1:
            self.aug = Normalize(self.means[sensors[0]], self.stds[sensors[0]]) if aug is None else aug
        else:
            # BUGFIX: the multimodal normalizer was previously constructed but
            # never assigned, leaving self.aug unset for multi-sensor runs.
            self.aug = MultimodalNormalize(self.means, self.stds) if aug is None else aug
        self.drop_last = drop_last
        self.as_time_series = as_time_series
        self.metadata_filename = metadata_filename
        self.max_cloud_percentage = max_cloud_percentage
        self.max_red_mean = max_red_mean
        self.include_corrupt = include_corrupt
        self.subset = subset
        self.seed = seed
        self.use_four_frames = use_four_frames

    def _build_dataset(self, split: str, transform: A.Compose | None):
        """Construct a BioMasstersNonGeo dataset for *split* with *transform*."""
        return self.dataset_class(
            split=split,
            root=self.data_root,
            transform=transform,
            bands=self.bands,
            mask_mean=self.mask_mean,
            mask_std=self.mask_std,
            sensors=self.sensors,
            as_time_series=self.as_time_series,
            metadata_filename=self.metadata_filename,
            max_cloud_percentage=self.max_cloud_percentage,
            max_red_mean=self.max_red_mean,
            include_corrupt=self.include_corrupt,
            subset=self.subset,
            seed=self.seed,
            use_four_frames=self.use_four_frames,
        )

    def setup(self, stage: str) -> None:
        """Create datasets for the requested Lightning *stage*."""
        if stage in ["fit"]:
            self.train_dataset = self._build_dataset("train", self.train_transform)
        if stage in ["fit", "validate"]:
            # The dataset's "test" split doubles as validation data.
            self.val_dataset = self._build_dataset("test", self.val_transform)
        if stage in ["test"]:
            self.test_dataset = self._build_dataset("test", self.test_transform)

    def _dataloader_factory(self, split: str):
        """Build a DataLoader for *split*; shuffle/drop_last only for training."""
        dataset = self._valid_attribute(f"{split}_dataset", "dataset")
        batch_size = self._valid_attribute(f"{split}_batch_size", "batch_size")
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == "train",
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            drop_last=split == "train" and self.drop_last,
        )
94 changes: 94 additions & 0 deletions terratorch/datamodules/burn_intensity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from collections.abc import Sequence
from typing import Any

import albumentations as A

from terratorch.datamodules.utils import NormalizeWithTimesteps, wrap_in_compose_is_list
from terratorch.datasets import BurnIntensityNonGeo
from torchgeo.datamodules import NonGeoDataModule

# Per-band normalization statistics for the BurnIntensity dataset.
# Each band maps to a 3-element list — one value per timestep, consumed by
# NormalizeWithTimesteps below.
# NOTE(review): values look precomputed over the training split — confirm.
MEANS = {
    "BLUE": [331.6921, 896.8024, 348.8031],
    "GREEN": [555.1077, 1093.9736, 500.2181],
    "RED": [605.2513, 1142.7225, 597.9034],
    "NIR": [1761.3884, 1890.2156, 1552.0403],
    "SWIR_1": [1117.1825, 1408.0839, 1293.0919],
    "SWIR_2": [2168.0090, 2270.9753, 1362.1312],
}

# Standard deviations matching the structure of MEANS above.
STDS = {
    "BLUE": [213.0656, 1620.4131, 314.7517],
    "GREEN": [273.0910, 1628.4181, 365.6746],
    "RED": [414.8322, 1600.7698, 424.8185],
    "NIR": [818.7486, 1236.8453, 804.9058],
    "SWIR_1": [677.2739, 1153.7432, 795.4156],
    "SWIR_2": [612.9131, 1495.8365, 661.6196],
}

class BurnIntensityNonGeoDataModule(NonGeoDataModule):
    """NonGeo datamodule implementation for BurnIntensity.

    Builds BurnIntensityNonGeo datasets per stage and normalizes batches with
    per-timestep statistics drawn from the module-level MEANS/STDS tables.
    """

    def __init__(
        self,
        data_root: str,
        batch_size: int = 4,
        num_workers: int = 0,
        bands: Sequence[str] = BurnIntensityNonGeo.all_band_names,
        train_transform: A.Compose | None | list[A.BasicTransform] = None,
        val_transform: A.Compose | None | list[A.BasicTransform] = None,
        test_transform: A.Compose | None | list[A.BasicTransform] = None,
        use_full_data: bool = True,
        no_data_replace: float | None = 0.0001,
        no_label_replace: int | None = -1,
        use_metadata: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize the datamodule.

        Args:
            data_root: Root directory of the dataset.
            batch_size: Samples per batch.
            num_workers: DataLoader worker processes.
            bands: Band names to load; order determines channel order.
            train_transform: Albumentations transform(s) for training.
            val_transform: Albumentations transform(s) for validation.
            test_transform: Albumentations transform(s) for testing.
            use_full_data: Forwarded to the dataset.
            no_data_replace: Fill value for missing pixels.
            no_label_replace: Fill value for missing labels.
            use_metadata: Whether the dataset should return metadata.
            **kwargs: Passed through to NonGeoDataModule.
        """
        super().__init__(BurnIntensityNonGeo, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.bands = bands
        self.train_transform = wrap_in_compose_is_list(train_transform)
        self.val_transform = wrap_in_compose_is_list(val_transform)
        self.test_transform = wrap_in_compose_is_list(test_transform)
        # Per-timestep normalization stats, ordered to match the chosen bands.
        self.aug = NormalizeWithTimesteps(
            [MEANS[band] for band in bands],
            [STDS[band] for band in bands],
        )
        self.use_full_data = use_full_data
        self.no_data_replace = no_data_replace
        self.no_label_replace = no_label_replace
        self.use_metadata = use_metadata

    def _make_dataset(self, split: str, transform: A.Compose | None):
        """Construct a BurnIntensityNonGeo dataset for *split*."""
        return self.dataset_class(
            split=split,
            data_root=self.data_root,
            transform=transform,
            bands=self.bands,
            use_full_data=self.use_full_data,
            no_data_replace=self.no_data_replace,
            no_label_replace=self.no_label_replace,
            use_metadata=self.use_metadata,
        )

    def setup(self, stage: str) -> None:
        """Create datasets for the requested Lightning *stage*."""
        if stage in ["fit"]:
            self.train_dataset = self._make_dataset("train", self.train_transform)
        if stage in ["fit", "validate"]:
            self.val_dataset = self._make_dataset("val", self.val_transform)
        if stage in ["test"]:
            # NOTE(review): the "val" split is reused for testing — confirm the
            # dataset ships no separate labeled test split.
            self.test_dataset = self._make_dataset("val", self.test_transform)
Loading

0 comments on commit 9754f5c

Please sign in to comment.