Skip to content

Commit

Permalink
Merge pull request #54 from IBM/improve/bands_definition
Browse files Browse the repository at this point in the history
Improve/bands definition
  • Loading branch information
Joao-L-S-Almeida authored Jul 26, 2024
2 parents 3b76391 + 1a254e4 commit da480f0
Show file tree
Hide file tree
Showing 7 changed files with 533 additions and 41 deletions.
12 changes: 6 additions & 6 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down Expand Up @@ -330,9 +330,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down
61 changes: 43 additions & 18 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
# Copyright contributors to the Terratorch project

"""Module containing generic dataset classes
"""
"""Module containing generic dataset classes"""

import glob
import os
from abc import ABC
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as A
import matplotlib as mpl
import numpy as np
import rioxarray
import torch
import xarray as xr
from albumentations.pytorch import ToTensorV2
from einops import rearrange
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
Expand All @@ -43,8 +39,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
Expand Down Expand Up @@ -73,8 +69,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
Expand All @@ -88,6 +84,7 @@ def __init__(
expected 0. Defaults to False.
"""
super().__init__()

self.split_file = split

label_data_root = label_data_root if label_data_root is not None else data_root
Expand Down Expand Up @@ -120,19 +117,24 @@ def __init__(
)
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
return Exception(msg) # noqa: PLE0101

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None

# If no transform is given, apply only to transform to torch tensor
self.transform = transform if transform else lambda **batch: to_tensor(batch)
# self.transform = transform if transform else ToTensorV2()
Expand All @@ -141,7 +143,7 @@ def __len__(self) -> int:
return len(self.image_files)

def __getitem__(self, index: int) -> dict[str, Any]:
image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy()
image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
# to channels last
if self.expand_temporal_dimension:
image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
Expand All @@ -151,9 +153,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
image = image[..., self.filter_indices]
output = {
"image": image.astype(np.float32) * self.constant_scale,
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
0
],
"filename": self.image_files[index],
}

if self.reduce_zero_label:
output["mask"] -= 1
if self.transform:
Expand All @@ -166,6 +171,26 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
data = data.fillna(nan_replace)
return data

def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
if bands_intervals is None:
return None
bands = []
for element in bands_intervals:
# if its an interval
if isinstance(element, tuple):
if len(element) != 2: # noqa: PLR2004
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
raise Exception(msg)
expanded_element = list(range(element[0], element[1] + 1))
bands.extend(expanded_element)
else:
bands.append(element)
# check the expansion didnt result in duplicate elements
if len(set(bands)) != len(bands):
msg = "Duplicate indices detected. Indices must be unique."
raise Exception(msg)
return bands


class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
"""GenericNonGeoSegmentationDataset"""
Expand All @@ -181,8 +206,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[str] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
class_names: list[str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
Expand Down Expand Up @@ -348,8 +373,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float | None = None,
Expand Down
16 changes: 10 additions & 6 deletions terratorch/datasets/generic_scalar_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
rgb_indices: list[int] | None = None,
dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
transform: A.Compose | None = None,
no_data_replace: float = 0,
Expand All @@ -64,8 +64,8 @@ def __init__(
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
rgb_indices (list[str], optional): Indices of RGB channels. Defaults to [0, 1, 2].
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be refered to by output_bands. Defaults to None.
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
constant_scale (float): Factor to multiply image values by. Defaults to 1.
transform (Albumentations.Compose | None): Albumentations transform to be applied.
Should end with ToTensorV2(). If used through the generic_data_module,
Expand Down Expand Up @@ -110,17 +110,21 @@ def is_valid_file(x):

self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

self.dataset_bands = dataset_bands
self.output_bands = output_bands
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
self.output_bands = self._generate_bands_intervals(output_bands)

if self.output_bands and not self.dataset_bands:
msg = "If output bands provided, dataset_bands must also be provided"
return Exception(msg) # noqa: PLE0101

# There is a special condition if the bands are defined as simple strings.
if self.output_bands:
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
msg = "Output bands must be a subset of dataset bands"
raise Exception(msg)

self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]

else:
self.filter_indices = None
# If no transform is given, apply only to transform to torch tensor
Expand Down
136 changes: 136 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- [0, 11]
output_bands:
- [1, 3]
- [4, 6]
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Loading

0 comments on commit da480f0

Please sign in to comment.