Skip to content

Commit

Permalink
Merge pull request #68 from IBM/segmentation/tests
Browse files Browse the repository at this point in the history
Allowing to overwrite Prithvi and new segmentation tests.
  • Loading branch information
Joao-L-S-Almeida authored Oct 14, 2024
2 parents a6c06c8 + 7d4b4a5 commit ab6628f
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 5 deletions.
4 changes: 2 additions & 2 deletions terratorch/datasets/generic_pixel_wise_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
if self.expand_temporal_dimension:
image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
image = np.moveaxis(image, 0, -1)

if self.filter_indices:
image = image[..., self.filter_indices]
output = {
Expand All @@ -168,7 +168,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
if self.transform:
output = self.transform(**output)
output["filename"] = self.image_files[index]

return output

def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
Expand Down
1 change: 0 additions & 1 deletion terratorch/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def load_from_file_or_attribute(value: list[float]|str):
elif isinstance(value, str): # It can be the path for a file
if os.path.isfile(value):
try:
print(value)
content = np.genfromtxt(value).tolist()
except:
raise Exception(f"File must be txt, but received {value}")
Expand Down
3 changes: 3 additions & 0 deletions terratorch/models/backbones/prithvi_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from terratorch.datasets.utils import HLSBands
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
from terratorch.datasets.utils import generate_bands_intervals

PRETRAINED_BANDS = [
HLSBands.BLUE,
Expand Down Expand Up @@ -172,6 +173,8 @@ def _create_swin_mmseg_transformer(
# the current swin model is not multitemporal
if "num_frames" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "num_frames"}

model_bands = generate_bands_intervals(model_bands)
kwargs["in_chans"] = len(model_bands)

def checkpoint_filter_wrapper_fn(state_dict, model):
Expand Down
8 changes: 7 additions & 1 deletion terratorch/models/backbones/prithvi_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from functools import partial
from pathlib import Path
from collections import defaultdict

from timm.models import FeatureInfo
from timm.models._builder import build_model_with_cfg
Expand All @@ -13,6 +14,7 @@
from terratorch.datasets import HLSBands
from terratorch.models.backbones.select_patch_embed_weights import select_patch_embed_weights
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.datasets.utils import generate_bands_intervals

PRETRAINED_BANDS = [
HLSBands.BLUE,
Expand Down Expand Up @@ -80,6 +82,8 @@ def _create_prithvi(
if "features_only" in kwargs:
kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}

model_bands = generate_bands_intervals(model_bands)

kwargs["in_chans"] = len(model_bands)

def checkpoint_filter_wrapper_fn(state_dict, model):
Expand Down Expand Up @@ -138,13 +142,15 @@ def create_prithvi_vit_100(
"norm_layer": partial(nn.LayerNorm, eps=1e-6),
"num_frames": 1,
}

model = _create_prithvi(
model_name,
pretrained=pretrained,
model_bands=bands,
pretrained_bands=pretrained_bands,
**dict(model_args, **kwargs),
**dict(model_args,**kwargs),
)

return model


Expand Down
1 change: 1 addition & 0 deletions terratorch/tasks/loss_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def compute_loss(
If there are auxiliary heads, the main decode head is returned under the key "decode_head".
All other heads are returned with the same key as their name.
"""

loss = self._compute_loss(model_output.output, ground_truth, criterion)
if not model_output.auxiliary_heads:
return {"loss": loss}
Expand Down
135 changes: 135 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: cpu
strategy: auto
devices: auto
num_nodes: 1
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 4
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- [0, 11]
output_bands:
- [1, 3]
- [4, 6]
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "segmentation*input*.tif"
label_grep: "segmentation*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0
num_classes: 2
model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: UperNetDecoder
pretrained: true
backbone: prithvi_swin_B
backbone_pretrained_cfg_overlay:
file: tests/prithvi_swin_B.pt
backbone_drop_path_rate: 0.3
# backbone_window_size: 8
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
num_classes: 2
head_dropout: 0.5708022831486758
loss: ce
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ trainer:
init_args:
monitor: val/loss
patience: 100
max_epochs: 2
max_epochs: 3
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
Expand Down
12 changes: 12 additions & 0 deletions tests/test_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def test_finetune_metrics_from_file(model_name):
_ = build_lightning_cli(command_list)

"""
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_segmentation(model_name):
model_instance = timm.create_model(model_name)
state_dict = model_instance.state_dict()
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
# Running the terratorch CLI
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_segmentation.yaml"]
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):
# Running the terratorch CLI
Expand Down

0 comments on commit ab6628f

Please sign in to comment.