Skip to content

Commit

Permalink
add missing geobench_data_module module
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed Aug 15, 2024
1 parent 9c8bbc0 commit 64384a4
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions terratorch/datamodules/geobench_data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Any

import albumentations as A
import kornia.augmentation as K # noqa: N812
import torch
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.transforms import AugmentationSequential

from terratorch.datamodules.utils import wrap_in_compose_is_list


class GeobenchDataModule(NonGeoDataModule):
def __init__(
self,
dataset_class: type,
means: dict[str, float],
stds: dict[str, float],
batch_size: int = 8,
num_workers: int = 0,
data_root: str = "./",
train_transform: A.Compose | None | list[A.BasicTransform] = None,
val_transform: A.Compose | None | list[A.BasicTransform] = None,
test_transform: A.Compose | None | list[A.BasicTransform] = None,
aug: AugmentationSequential = None,
partition: str = "default",
**kwargs: Any,
) -> None:
super().__init__(dataset_class, batch_size, num_workers, **kwargs)

bands = kwargs.get("bands", dataset_class.all_band_names)
self.means = torch.tensor([means[b] for b in bands])
self.stds = torch.tensor([stds[b] for b in bands])
self.train_transform = wrap_in_compose_is_list(train_transform)
self.val_transform = wrap_in_compose_is_list(val_transform)
self.test_transform = wrap_in_compose_is_list(test_transform)
self.data_root = data_root
self.partition = partition
self.aug = (
AugmentationSequential(K.Normalize(self.means, self.stds), data_keys=["image"]) if aug is None else aug
)

def setup(self, stage: str) -> None:
if stage in ["fit"]:
self.train_dataset = self.dataset_class(
split="train",
data_root=self.data_root,
transform=self.train_transform,
partition=self.partition,
**self.kwargs,
)
if stage in ["fit", "validate"]:
self.val_dataset = self.dataset_class(
split="val",
data_root=self.data_root,
transform=self.val_transform,
partition=self.partition,
**self.kwargs,
)
if stage in ["test"]:
self.test_dataset = self.dataset_class(
split="test",
data_root=self.data_root,
transform=self.test_transform,
partition=self.partition,
**self.kwargs,
)

0 comments on commit 64384a4

Please sign in to comment.