Improve/bands definition #54

Merged
merged 22 commits into from
Jul 26, 2024
Changes from 13 commits
Commits
63d4733
Bands can be defined by intervals
Joao-L-S-Almeida Jul 18, 2024
8731019
Constructing the bands using the definition by interval
Joao-L-S-Almeida Jul 19, 2024
1dd6650
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
a0ce8aa
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
295128f
Testing the definition by interval using a dedicated yaml file
Joao-L-S-Almeida Jul 19, 2024
64dcf5d
Special case for bands_list=:None
Joao-L-S-Almeida Jul 19, 2024
e48965f
Basic support to use simple strings to name the bands
Joao-L-S-Almeida Jul 22, 2024
5dba481
Strings are allowed to define bands
Joao-L-S-Almeida Jul 22, 2024
174f2f1
Testing to use strings to define a model
Joao-L-S-Almeida Jul 22, 2024
989bf80
Exception for None inputs
Joao-L-S-Almeida Jul 22, 2024
831c662
Support for str
Joao-L-S-Almeida Jul 22, 2024
de533dd
YAML file for testing string as bands
Joao-L-S-Almeida Jul 22, 2024
7a94a82
This test is no longer required
Joao-L-S-Almeida Jul 22, 2024
0a84c42
Band intervals should be tuples with two entries
Joao-L-S-Almeida Jul 23, 2024
2888252
More compact way to check if the bands are defined by interval
Joao-L-S-Almeida Jul 23, 2024
303bfe8
This warning is not necessary
Joao-L-S-Almeida Jul 23, 2024
25b10c6
Reformatting using black
Joao-L-S-Almeida Jul 23, 2024
10551a0
Minor improvements
Joao-L-S-Almeida Jul 23, 2024
aedca32
Missing imports
Joao-L-S-Almeida Jul 23, 2024
5708e0a
More tests to check if the bands are properly returned
Joao-L-S-Almeida Jul 24, 2024
3faf7d3
accept mixed band specifications
CarlosGomes98 Jul 26, 2024
1a254e4
improve docstring comments
CarlosGomes98 Jul 26, 2024
12 changes: 6 additions & 6 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
@@ -91,9 +91,9 @@ def __init__(
         test_split: Path | None = None,
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
-        dataset_bands: list[HLSBands | int] | None = None,
-        predict_dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        predict_dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        output_bands: list[HLSBands | int | list[int] | str] | None = None,
         constant_scale: float = 1,
         rgb_indices: list[int] | None = None,
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -330,9 +330,9 @@ def __init__(
         test_split: Path | None = None,
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
-        dataset_bands: list[HLSBands | int] | None = None,
-        predict_dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        predict_dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        output_bands: list[HLSBands | int | list[int] | str] | None = None,
         constant_scale: float = 1,
         rgb_indices: list[int] | None = None,
         train_transform: A.Compose | None | list[A.BasicTransform] = None,
78 changes: 68 additions & 10 deletions terratorch/datasets/generic_pixel_wise_dataset.py
@@ -7,7 +7,7 @@
 from abc import ABC
 from functools import partial
 from pathlib import Path
-from typing import Any
+from typing import Any, List, Union

 import albumentations as A
 import matplotlib as mpl
@@ -43,8 +43,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[int] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        output_bands: list[HLSBands | int | list[int] | str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
         no_data_replace: float | None = None,
@@ -88,6 +88,7 @@ def __init__(
                 expected 0. Defaults to False.
         """
         super().__init__()
+
         self.split_file = split

         label_data_root = label_data_root if label_data_root is not None else data_root
@@ -120,19 +121,38 @@ def __init__(
         )
         self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices

-        self.dataset_bands = dataset_bands
-        self.output_bands = output_bands
+        bands_by_interval = (self._bands_defined_by_interval(bands_list=dataset_bands) and
+                             self._bands_defined_by_interval(bands_list=output_bands))
+
+        # Expand the bands into explicit indices when they are defined by sub-intervals.
+        if bands_by_interval:
+            self.dataset_bands = self._generate_bands_intervals(dataset_bands)
+            self.output_bands = self._generate_bands_intervals(output_bands)
+        else:
+            self.dataset_bands = dataset_bands
+            self.output_bands = output_bands
+
+        bands_type = self._bands_as_int_or_str(dataset_bands, output_bands)
+
+        if bands_type == str:
+            UserWarning("When the bands are defined as str, make sure your input files " +
+                        "are organized by band and each has its own specific name.")
+
         if self.output_bands and not self.dataset_bands:
             msg = "If output bands provided, dataset_bands must also be provided"
             raise Exception(msg)

+        # There is a special condition if the bands are defined as simple strings.
         if self.output_bands:
             if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
                 msg = "Output bands must be a subset of dataset bands"
                 raise Exception(msg)
+
             self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
+
         else:
             self.filter_indices = None
+
         # If no transform is given, only convert the batch to torch tensors
         self.transform = transform if transform else lambda **batch: to_tensor(batch)
         # self.transform = transform if transform else ToTensorV2()
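The subset check and the `filter_indices` computation in the hunk above can be sketched as a standalone function. This is only an illustration under stated assumptions: `compute_filter_indices` is a hypothetical name that does not exist in terratorch, and it raises `ValueError` where the diff raises a bare `Exception`.

```python
def compute_filter_indices(dataset_bands, output_bands):
    """Return the position of each output band inside dataset_bands.

    Hypothetical helper mirroring the logic in the diff; not part of terratorch.
    Returns None when no output bands are requested (no filtering).
    """
    if not output_bands:
        return None
    if not dataset_bands:
        raise ValueError("If output bands provided, dataset_bands must also be provided")
    # Every output band must also be present in the dataset bands.
    if len(set(output_bands) & set(dataset_bands)) != len(output_bands):
        raise ValueError("Output bands must be a subset of dataset bands")
    return [dataset_bands.index(band) for band in output_bands]


# Integer bands: positions of bands 1..6 within bands 0..11.
print(compute_filter_indices(list(range(12)), [1, 2, 3, 4, 5, 6]))

# String bands work the same way, since only equality and hashing are needed.
print(compute_filter_indices(["BLUE", "GREEN", "RED", "NIR"], ["RED", "BLUE"]))
```

Because the check compares set sizes, a duplicated entry in `output_bands` is also rejected, which matches the behaviour of the condition in the diff.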
@@ -154,18 +174,56 @@ def __getitem__(self, index: int) -> dict[str, Any]:
             "mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
             "filename": self.image_files[index],
         }
+
         if self.reduce_zero_label:
             output["mask"] -= 1
         if self.transform:
             output = self.transform(**output)
         return output

     def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
         data = rioxarray.open_rasterio(path, masked=True)
         if nan_replace is not None:
             data = data.fillna(nan_replace)
         return data

+    def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None):
+        bands = list()
+        for b_interval in bands_intervals:
+            bands_sublist = np.arange(b_interval[0], b_interval[1] + 1).astype(int).tolist()
+            bands.append(bands_sublist)
+        return sorted(sum(bands, []))
+
+    def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
+
+        band_type = [None, None]
+        if not dataset_bands and not output_bands:
+            return None
+        else:
+            for b, bands_list in enumerate([dataset_bands, output_bands]):
+                if all([type(band)==int for band in bands_list]):
+                    band_type[b] = int
+                elif all([type(band)==str for band in bands_list]):
+                    band_type[b] = str
+                else:
+                    pass
+            if band_type.count(band_type[0]) == len(band_type):
+                return band_type[0]
+            else:
+                raise Exception("The bands must be either all str or all int.")
+
+    def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool:
+        if not bands_list:
+            return False
+        elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]):
+            return False
+        elif all([isinstance(subinterval, list) for subinterval in bands_list]):
+            if all([type(band)==int for band in sum(bands_list, [])]):
+                return True
+            else:
+                raise Exception("When using subintervals, the limits must be int.")
+        else:
+            raise Exception(f"Expected List[int] or List[str] or List[List[int]], but received {type(bands_list)}.")

 class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
     """GenericNonGeoSegmentationDataset"""
@@ -181,8 +239,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[str] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        output_bands: list[HLSBands | int | list[int] | str] | None = None,
         class_names: list[str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
@@ -348,8 +406,8 @@ def __init__(
         ignore_split_file_extensions: bool = True,
         allow_substring_split_file: bool = True,
         rgb_indices: list[int] | None = None,
-        dataset_bands: list[HLSBands | int] | None = None,
-        output_bands: list[HLSBands | int] | None = None,
+        dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
+        output_bands: list[HLSBands | int | list[int] | str] | None = None,
         constant_scale: float = 1,
         transform: A.Compose | None = None,
         no_data_replace: float | None = None,
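The interval handling added to `generic_pixel_wise_dataset.py` treats each `[start, end]` pair as an inclusive range and flattens all pairs into one sorted list. A minimal standalone sketch of that idea follows. The names `expand_band_intervals` and `defined_by_interval` are hypothetical and not part of terratorch, the sketch uses plain `range` instead of `np.arange`, and accepting tuples as well as lists is an assumption.

```python
from typing import Sequence


def defined_by_interval(bands) -> bool:
    """True when the spec is a non-empty sequence made only of lists/tuples."""
    return bool(bands) and all(isinstance(b, (list, tuple)) for b in bands)


def expand_band_intervals(intervals: Sequence[Sequence[int]]) -> list[int]:
    """Expand inclusive [start, end] pairs into a sorted flat list of band indices."""
    bands: list[int] = []
    for interval in intervals:
        if len(interval) != 2:
            raise ValueError(f"Each interval needs exactly two entries, got {interval!r}")
        start, end = interval
        if type(start) is not int or type(end) is not int:
            raise TypeError("When using subintervals, the limits must be int.")
        # The upper limit is inclusive, matching `b_interval[1] + 1` in the diff.
        bands.extend(range(start, end + 1))
    return sorted(bands)


spec = [[1, 3], [4, 6]]
if defined_by_interval(spec):
    print(expand_band_intervals(spec))  # [1, 2, 3, 4, 5, 6]
```

Using `list.extend` avoids the quadratic cost of `sum(bands, [])` on long lists, although for a handful of bands the difference is negligible.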
136 changes: 136 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_band_interval.yaml
@@ -0,0 +1,136 @@
+# lightning.pytorch==2.1.1
+seed_everything: 42
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  # precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: tests/
+      name: all_ecos_random
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 100
+  max_epochs: 5
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 20
+  enable_checkpointing: true
+  default_root_dir: tests/
+data:
+  class_path: GenericNonGeoPixelwiseRegressionDataModule
+  init_args:
+    batch_size: 2
+    num_workers: 4
+    train_transform:
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: albumentations.Rotate
+        init_args:
+          limit: 30
+          border_mode: 0 # cv2.BORDER_CONSTANT
+          value: 0
+          # mask_value: 1
+          p: 0.5
+      - class_path: ToTensorV2
+    dataset_bands:
+      - [0, 11]
+    output_bands:
+      - [1, 3]
+      - [4, 6]
+    rgb_indices:
+      - 2
+      - 1
+      - 0
+    train_data_root: tests/
+    train_label_data_root: tests/
+    val_data_root: tests/
+    val_label_data_root: tests/
+    test_data_root: tests/
+    test_label_data_root: tests/
+    img_grep: "regression*input*.tif"
+    label_grep: "regression*label*.tif"
+    means:
+      - 547.36707
+      - 898.5121
+      - 1020.9082
+      - 2665.5352
+      - 2340.584
+      - 1610.1407
+    stds:
+      - 411.4701
+      - 558.54065
+      - 815.94025
+      - 812.4403
+      - 1113.7145
+      - 1067.641
+    no_label_replace: -1
+    no_data_replace: 0
+
+model:
+  class_path: terratorch.tasks.PixelwiseRegressionTask
+  init_args:
+    model_args:
+      decoder: UperNetDecoder
+      pretrained: true
+      backbone: prithvi_swin_B
+      backbone_pretrained_cfg_overlay:
+        file: tests/prithvi_swin_B.pt
+      backbone_drop_path_rate: 0.3
+      # backbone_window_size: 8
+      decoder_channels: 256
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      head_dropout: 0.5708022831486758
+      head_final_act: torch.nn.ReLU
+      head_learned_upscale_layers: 2
+    loss: rmse
+    #aux_heads:
+    #  - name: aux_head
+    #    decoder: IdentityDecoder
+    #    decoder_args:
+    #      decoder_out_index: 2
+    #      head_dropout: 0,5
+    #      head_channel_list:
+    #        - 64
+    #      head_final_act: torch.nn.ReLU
+    #aux_loss:
+    #  aux_head: 0.4
+    ignore_index: -1
+    freeze_backbone: true
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+
+    # uncomment this block for tiled inference
+    # tiled_inference_parameters:
+    #   h_crop: 224
+    #   h_stride: 192
+    #   w_crop: 224
+    #   w_stride: 192
+    #   average_patches: true
+optimizer:
+  class_path: torch.optim.AdamW
+  init_args:
+    lr: 0.00013524680528283027
+    weight_decay: 0.047782217873995426
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
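In the test configuration above, `output_bands` is given as the intervals `[1, 3]` and `[4, 6]`, which cover six band indices, the same number as the `means` and `stds` entries and the `in_channels` value. The sketch below checks that consistency for the values copied from the YAML file. `count_bands` and `check_band_counts` are hypothetical helpers, not terratorch functions, and the sketch assumes the intervals do not overlap; whether terratorch itself enforces this agreement is not stated in the diff.

```python
def count_bands(intervals):
    """Number of band indices covered by inclusive, non-overlapping [start, end] intervals."""
    return sum(end - start + 1 for start, end in intervals)


def check_band_counts(output_bands, means, stds, in_channels):
    """Raise ValueError when the per-band settings disagree with the band count."""
    n_bands = count_bands(output_bands)
    sizes = {"means": len(means), "stds": len(stds), "in_channels": in_channels}
    mismatched = {name: size for name, size in sizes.items() if size != n_bands}
    if mismatched:
        raise ValueError(f"{n_bands} output bands, but found {mismatched}")
    return n_bands


# Values copied from the YAML file in the diff.
output_bands = [[1, 3], [4, 6]]
means = [547.36707, 898.5121, 1020.9082, 2665.5352, 2340.584, 1610.1407]
stds = [411.4701, 558.54065, 815.94025, 812.4403, 1113.7145, 1067.641]

print(check_band_counts(output_bands, means, stds, in_channels=6))  # 6
```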
