Commit
Merge remote-tracking branch 'upstream/main'
PedroConrado committed Aug 2, 2024
2 parents 4a29f56 + 6c483a8 commit fa33060
Showing 19 changed files with 907 additions and 300 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -39,7 +39,9 @@ dependencies = [
"geobench>=1.0.0",
"mlflow>=2.12.1",
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
"lightning>=2, <=2.2.5"
"lightning>=2, <=2.2.5",
# see issue #64
"albumentations<=1.4.10"
]

[project.optional-dependencies]
32 changes: 19 additions & 13 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
@@ -3,11 +3,11 @@
"""
This module contains generic data modules for instantiation at runtime.
"""

import os
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import numpy as np
import albumentations as A
import kornia.augmentation as K
import torch
@@ -17,7 +17,7 @@
from torchgeo.transforms import AugmentationSequential

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
@@ -79,8 +79,8 @@ def __init__(
test_data_root: Path,
img_grep: str,
label_grep: str,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_label_data_root: Path | None = None,
@@ -91,9 +91,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -198,6 +198,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
@@ -317,8 +320,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
predict_data_root: Path | None = None,
img_grep: str | None = "*",
label_grep: str | None = "*",
@@ -330,9 +333,9 @@
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -430,6 +433,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)
self.no_data_replace = no_data_replace
self.no_label_replace = no_label_replace
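
For context, both datamodules in this file can now take their normalization statistics either as in-line lists or as paths to plain-text files, which are resolved through load_from_file_or_attribute before being handed to Normalize. A minimal sketch of the file-based style follows; the class name, the batch_size/num_workers arguments, and all paths are assumptions for illustration and do not appear in this diff.

from terratorch.datamodules.generic_pixel_wise_data_module import GenericNonGeoSegmentationDataModule

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=8,                  # assumed arguments, not shown in this hunk
    num_workers=2,
    train_data_root="data/train",  # hypothetical paths
    val_data_root="data/val",
    test_data_root="data/test",
    img_grep="*_img.tif",
    label_grep="*_mask.tif",
    means="stats/means.txt",       # one value per band, readable by np.genfromtxt
    stds="stats/stds.txt",         # means=[...] and stds=[...] remain valid as well
    num_classes=2,
)
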
10 changes: 7 additions & 3 deletions terratorch/datamodules/generic_scalar_label_data_module.py
@@ -22,12 +22,12 @@
HLSBands,
)

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list


class Normalize(Callable):
def __init__(self, means, stds):
super().__init__()
@@ -68,8 +68,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_split: Path | None = None,
@@ -166,6 +166,10 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )

means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
64 changes: 45 additions & 19 deletions terratorch/datasets/generic_pixel_wise_dataset.py
@@ -1,23 +1,19 @@
# Copyright contributors to the Terratorch project

"""Module containing generic dataset classes
"""
"""Module containing generic dataset classes"""

import glob
import os
from abc import ABC
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as A
import matplotlib as mpl
import numpy as np
import rioxarray
import torch
import xarray as xr
from albumentations.pytorch import ToTensorV2
from einops import rearrange
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
@@ -43,8 +39,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
@@ -73,8 +69,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names the input channels (bands) using HLSBands, ints, inclusive int ranges given as tuples, or strings, so that they can then be referred to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
@@ -88,6 +84,7 @@ def __init__(
expected 0. Defaults to False.
"""
super().__init__()

self.split_file = split

label_data_root = label_data_root if label_data_root is not None else data_root
@@ -120,19 +117,24 @@ def __init__(
)
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None

        # If no transform is given, only convert to torch tensor
self.transform = transform if transform else lambda **batch: to_tensor(batch)
# self.transform = transform if transform else ToTensorV2()
@@ -141,7 +143,7 @@ def __len__(self) -> int:
return len(self.image_files)

def __getitem__(self, index: int) -> dict[str, Any]:
image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy()
image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
# to channels last
if self.expand_temporal_dimension:
image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
@@ -151,13 +153,17 @@ def __getitem__(self, index: int) -> dict[str, Any]:
image = image[..., self.filter_indices]
output = {
"image": image.astype(np.float32) * self.constant_scale,
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
"filename": self.image_files[index],
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
0
]
}

if self.reduce_zero_label:
output["mask"] -= 1
if self.transform:
output = self.transform(**output)
output["filename"] = self.image_files[index]

return output

def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
@@ -166,6 +172,26 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
data = data.fillna(nan_replace)
return data

def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
if bands_intervals is None:
return None
bands = []
for element in bands_intervals:
            # if it's an interval
if isinstance(element, tuple):
if len(element) != 2: # noqa: PLR2004
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1] + 1))
bands.extend(expanded_element)
else:
bands.append(element)
        # check that the expansion didn't result in duplicate elements
if len(set(bands)) != len(bands):
msg = "Duplicate indices detected. Indices must be unique."
raise Exception(msg)
return bands


class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
"""GenericNonGeoSegmentationDataset"""
@@ -181,8 +207,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[str] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
class_names: list[str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
@@ -348,8 +374,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
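
To make the new band syntax concrete, the following standalone sketch mirrors the expansion performed by _generate_bands_intervals and the subsequent resolution of output_bands into filter_indices; the band names are hypothetical and the helper below is illustrative, not part of terratorch.

def expand_bands(bands_intervals):
    # Expand (start, end) tuples inclusively; leave ints, strings and HLSBands untouched.
    bands = []
    for element in bands_intervals:
        if isinstance(element, tuple):
            bands.extend(range(element[0], element[1] + 1))
        else:
            bands.append(element)
    return bands

dataset_bands = expand_bands([(0, 3), "CLOUD_MASK"])  # -> [0, 1, 2, 3, "CLOUD_MASK"]
output_bands = expand_bands([(1, 2), "CLOUD_MASK"])   # -> [1, 2, "CLOUD_MASK"]

# Every output band must exist in dataset_bands; the kept channel positions
# are then looked up in the expanded list, exactly as in the dataset above.
assert set(output_bands) <= set(dataset_bands)
filter_indices = [dataset_bands.index(band) for band in output_bands]
print(filter_indices)  # [1, 2, 4]
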
23 changes: 13 additions & 10 deletions terratorch/datasets/generic_scalar_label_dataset.py
@@ -42,8 +42,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float = 0,
@@ -64,8 +64,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
            dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names the input channels (bands) using HLSBands, ints, inclusive int ranges given as tuples, or strings, so that they can then be referred to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
@@ -110,17 +110,21 @@ def is_valid_file(x):

self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
            raise Exception(msg)

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None
        # If no transform is given, only convert to torch tensor
@@ -139,13 +143,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:

output = {
"image": image.astype(np.float32) * self.constant_scale,
"label": label,
"filename": self.samples[index][
0
], # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
"label": label, # samples is an attribute of ImageFolder. Contains a tuple of (Path, Target)
}
if self.transforms:
output = self.transforms(**output)
output["filename"] = self.image_files[index]

return output

def _load_file(self, path) -> xr.DataArray:
19 changes: 19 additions & 0 deletions terratorch/io/file.py
@@ -1,6 +1,7 @@
import os
import importlib
from torch import nn
import numpy as np

def open_generic_torch_model(model: type | str = None,
model_kwargs: dict = None,
@@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
)

return model

def load_from_file_or_attribute(value: list[float]|str):

if isinstance(value, list):
return value
elif isinstance(value, str): # It can be the path for a file
if os.path.isfile(value):
            try:
                content = np.genfromtxt(value).tolist()
            except Exception as exc:
                raise Exception(f"File must be txt, but received {value}") from exc
else:
raise Exception(f"The input {value} does not exist or is not a file.")

return content
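
A short usage sketch of the new helper; the only assumption is that the statistics file is plain text readable by np.genfromtxt (the file name below is hypothetical).

import numpy as np

from terratorch.io.file import load_from_file_or_attribute

# Lists are returned unchanged.
means = load_from_file_or_attribute([495.75, 814.08, 924.87])

# Strings are treated as paths to plain-text files and parsed with np.genfromtxt.
np.savetxt("means.txt", np.array([495.75, 814.08, 924.87]))
means_from_file = load_from_file_or_attribute("means.txt")

assert np.allclose(means, means_from_file)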



