Skip to content

Commit

Permalink
Merge pull request #61 from IBM/read_std_mean_from_file
Browse files Browse the repository at this point in the history
Read std mean from file
  • Loading branch information
Joao-L-S-Almeida authored Jul 30, 2024
2 parents 3670c7b + fae2531 commit 9df5f5f
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 10 deletions.
20 changes: 13 additions & 7 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""
This module contains generic data modules for instantiation at runtime.
"""

import os
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import numpy as np
import albumentations as A
import kornia.augmentation as K
import torch
Expand All @@ -17,7 +17,7 @@
from torchgeo.transforms import AugmentationSequential

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
Expand Down Expand Up @@ -79,8 +79,8 @@ def __init__(
test_data_root: Path,
img_grep: str,
label_grep: str,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_label_data_root: Path | None = None,
Expand Down Expand Up @@ -198,6 +198,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down Expand Up @@ -317,8 +320,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
predict_data_root: Path | None = None,
img_grep: str | None = "*",
label_grep: str | None = "*",
Expand Down Expand Up @@ -430,6 +433,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)
self.no_data_replace = no_data_replace
self.no_label_replace = no_label_replace
Expand Down
10 changes: 7 additions & 3 deletions terratorch/datamodules/generic_scalar_label_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
HLSBands,
)

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list


class Normalize(Callable):
def __init__(self, means, stds):
super().__init__()
Expand Down Expand Up @@ -68,8 +68,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_split: Path | None = None,
Expand Down Expand Up @@ -166,6 +166,10 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )

means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down
19 changes: 19 additions & 0 deletions terratorch/io/file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import importlib
from torch import nn
import numpy as np

def open_generic_torch_model(model: type | str = None,
model_kwargs: dict = None,
Expand Down Expand Up @@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
)

return model

def load_from_file_or_attribute(value: list[float]|str):

if isinstance(value, list):
return value
elif isinstance(value, str): # It can be the path for a file
if os.path.isfile(value):
try:
print(value)
content = np.genfromtxt(value).tolist()
except:
raise Exception(f"File must be txt, but received {value}")
else:
raise Exception(f"The input {value} does not exist or is not a file.")

return content


124 changes: 124 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_metrics_from_file.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoPixelwiseRegressionDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- [0, 11]
output_bands:
- [1, 3]
- [4, 6]
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "regression*input*.tif"
label_grep: "regression*label*.tif"
means: tests/means.txt
stds: tests/stds.txt
no_label_replace: -1
no_data_replace: 0

model:
class_path: terratorch.tasks.PixelwiseRegressionTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
loss: rmse
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

7 changes: 7 additions & 0 deletions tests/means.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
411.4701
558.54065
815.94025
812.4403
1113.7145
1067.641

7 changes: 7 additions & 0 deletions tests/stds.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
547.36707
898.5121
1020.9082
2665.5352
2340.584
1610.1407

13 changes: 13 additions & 0 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,17 @@ def test_finetune_bands_str(model_name):
# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_string.yaml"]
_ = build_lightning_cli(command_list)

@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):

model_instance = timm.create_model(model_name)

state_dict = model_instance.state_dict()

torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_metrics_from_file.yaml"]
_ = build_lightning_cli(command_list)

0 comments on commit 9df5f5f

Please sign in to comment.