Allowing to overwrite Prithvi and new segmentation tests. #68

Merged
25 commits merged on Oct 14, 2024
Commits (25 total; changes shown from 20 commits)
10f82fa
Testing backbones for segmentation
Joao-L-S-Almeida Aug 1, 2024
710f79e
Basic test for segmentation fine-tuning tasks
Joao-L-S-Almeida Aug 1, 2024
023caf2
Testing using more samples
Joao-L-S-Almeida Aug 1, 2024
fdaeea3
Logging, remove it after the tests are finished
Joao-L-S-Almeida Aug 1, 2024
278c4d0
Merging updates from main
Joao-L-S-Almeida Aug 7, 2024
3fe4ac6
Removing unnecessary logging
Joao-L-S-Almeida Aug 7, 2024
ee17c96
Logging
Joao-L-S-Almeida Aug 8, 2024
233a5c2
The number of channels can be estimated in different ways according t…
Joao-L-S-Almeida Aug 9, 2024
6bd2bde
terratorch.models.backbones.utils
Joao-L-S-Almeida Aug 9, 2024
52f0c12
HLSBands also need be considered
Joao-L-S-Almeida Aug 9, 2024
3d9b5f5
Better conditional for all the kinds of bands
Joao-L-S-Almeida Aug 9, 2024
5cf5876
Allowing the configuration to overwrite default arguments from the mo…
Joao-L-S-Almeida Aug 9, 2024
ec8125e
Intervals must include the upper extreme
Joao-L-S-Almeida Aug 16, 2024
f4e9c47
Merge pull request #106 from IBM/extend/bands
Joao-L-S-Almeida Aug 20, 2024
cf963a8
Merge pull request #107 from IBM/overwrite_default_prithvi
Joao-L-S-Almeida Aug 20, 2024
7c6a5a2
Merge pull request #126 from IBM/extend/bands
Joao-L-S-Almeida Aug 20, 2024
28a8dfb
Solving conflict in tests/test_finetune.py
Joao-L-S-Almeida Aug 20, 2024
5da9522
Merge branch 'segmentation/tests' of github.com:IBM/terratorch into s…
Joao-L-S-Almeida Aug 20, 2024
8007bdf
Run the tests again
Joao-L-S-Almeida Aug 22, 2024
6055503
Merge branch 'main' into segmentation/tests
Joao-L-S-Almeida Aug 30, 2024
0220ffc
revert modification
Joao-L-S-Almeida Sep 6, 2024
a7aac27
This function is no more required
Joao-L-S-Almeida Sep 6, 2024
b10d892
This function is no more required
Joao-L-S-Almeida Sep 6, 2024
3232145
merging
Joao-L-S-Almeida Sep 11, 2024
7d4b4a5
No more required.
Joao-L-S-Almeida Sep 19, 2024
4 changes: 2 additions & 2 deletions terratorch/datasets/generic_pixel_wise_dataset.py
@@ -152,7 +152,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
        if self.expand_temporal_dimension:
            image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
        image = np.moveaxis(image, 0, -1)

        if self.filter_indices:
            image = image[..., self.filter_indices]
        output = {
@@ -167,7 +167,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
        if self.transform:
            output = self.transform(**output)
        output["filename"] = self.image_files[index]

        return output

    def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
1 change: 0 additions & 1 deletion terratorch/io/file.py
@@ -60,7 +60,6 @@ def load_from_file_or_attribute(value: list[float]|str):
    elif isinstance(value, str): # It can be the path for a file
        if os.path.isfile(value):
            try:
                print(value)
                content = np.genfromtxt(value).tolist()
            except:
                raise Exception(f"File must be txt, but received {value}")
4 changes: 3 additions & 1 deletion terratorch/models/backbones/prithvi_swin.py
@@ -18,6 +18,7 @@
from terratorch.datasets.utils import HLSBands
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
from terratorch.models.backbones.swin_encoder_decoder import MMSegSwinTransformer
from terratorch.models.backbones.utils import _estimate_in_chans

PRETRAINED_BANDS = [
    HLSBands.BLUE,
@@ -174,7 +175,8 @@ def _create_swin_mmseg_transformer(
    # the current swin model is not multitemporal
    if "num_frames" in kwargs:
        kwargs = {k: v for k, v in kwargs.items() if k != "num_frames"}
    kwargs["in_chans"] = len(model_bands)

    kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)

    def checkpoint_filter_wrapper_fn(state_dict, model):
        return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)
15 changes: 13 additions & 2 deletions terratorch/models/backbones/prithvi_vit.py
@@ -4,6 +4,7 @@
import logging
from functools import partial
from pathlib import Path
from collections import defaultdict

import torch
from timm.models import FeatureInfo
@@ -14,6 +15,7 @@
from terratorch.datasets import HLSBands
from terratorch.models.backbones.prithvi_select_patch_embed_weights import prithvi_select_patch_embed_weights
from terratorch.models.backbones.vit_encoder_decoder import TemporalViTEncoder
from terratorch.models.backbones.utils import _estimate_in_chans

PRETRAINED_BANDS = [
    HLSBands.BLUE,
@@ -81,7 +83,7 @@ def _create_prithvi(
    if "features_only" in kwargs:
        kwargs = {k: v for k, v in kwargs.items() if k != "features_only"}

    kwargs["in_chans"] = len(model_bands)
    kwargs["in_chans"] = _estimate_in_chans(model_bands=model_bands)

    def checkpoint_filter_wrapper_fn(state_dict, model):
        return checkpoint_filter_fn(state_dict, model, pretrained_bands, model_bands)
@@ -139,13 +141,22 @@ def create_prithvi_vit_100(
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
        "num_frames": 1,
    }

    # It is possible to overwrite default parameters using
    # config file
    kwargs_ = defaultdict()
    kwargs_.update(model_args)
    kwargs_.update(kwargs)
    kwargs_ = dict(kwargs_)

    model = _create_prithvi(
        model_name,
        pretrained=pretrained,
        model_bands=bands,
        pretrained_bands=pretrained_bands,
        **dict(model_args, **kwargs),
        **kwargs_,
    )

    return model
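The practical effect of the new kwargs_ merge is that values coming from the user's config take precedence over the factory's built-in model_args. A minimal sketch of that precedence (num_frames: 1 is the default shown above; the override value 3 is purely illustrative):

# Config kwargs overwrite the factory defaults: the later update wins.
model_args = {"num_frames": 1}   # default declared in create_prithvi_vit_100
kwargs = {"num_frames": 3}       # hypothetical value supplied via the config file

kwargs_ = dict(model_args)
kwargs_.update(kwargs)
assert kwargs_ == {"num_frames": 3}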


35 changes: 35 additions & 0 deletions terratorch/models/backbones/utils.py
@@ -0,0 +1,35 @@
from terratorch.datasets import HLSBands

[Contributor review comment on _are_sublists_of_int] I believe this is never used. if so, can we remove it?

def _are_sublists_of_int(item) -> tuple[bool, bool]:
    if all([isinstance(i, list) for i in item]):
        if all([isinstance(i, int) for i in sum(item, [])]):
            return True, True
        else:
            raise Exception(f"Expected sublists to be [int, int], but received {item}")
    elif len(item) == 2 and type(item[0]) == type(item[1]) == int:
        return False, True
    else:
        return False, False

def _estimate_in_chans(model_bands: list[HLSBands] | list[str] | tuple[int, int] = None) -> int:

    # Conditional to deal with the different possible choices for the bands
    # Bands as lists of strings or enums
    is_sublist, requires_special_eval = _are_sublists_of_int(model_bands)

    # Bands as intervals delimited by integers
    # Band numbering follows the Python convention (starts at 0)
    # but includes both extremes, hence the +1 to count the last band
    if requires_special_eval:

        if is_sublist:
            in_chans = sum([i[-1] - i[0] + 1 for i in model_bands])
        else:
            in_chans = model_bands[-1] - model_bands[0] + 1
    else:
        in_chans = len(model_bands)

    return in_chans
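As a worked illustration (a sketch only, using the import path introduced by this PR), _estimate_in_chans handles the three supported band specifications as follows:

from terratorch.models.backbones.utils import _estimate_in_chans

# Bands given as names (strings or HLSBands enums): the channel count is the list length.
print(_estimate_in_chans(model_bands=["BLUE", "GREEN", "RED"]))   # 3

# Bands given as a single [start, end] integer interval: both extremes are counted.
print(_estimate_in_chans(model_bands=[0, 11]))                    # 12

# Bands given as several [start, end] sub-intervals: the interval sizes are summed.
print(_estimate_in_chans(model_bands=[[1, 3], [4, 6]]))           # 6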


1 change: 1 addition & 0 deletions terratorch/tasks/loss_handler.py
@@ -40,6 +40,7 @@ def compute_loss(
        If there are auxiliary heads, the main decode head is returned under the key "decode_head".
        All other heads are returned with the same key as their name.
        """

        loss = self._compute_loss(model_output.output, ground_truth, criterion)
        if not model_output.auxiliary_heads:
            return {"loss": loss}
135 changes: 135 additions & 0 deletions tests/manufactured-finetune_prithvi_swin_B_segmentation.yaml
@@ -0,0 +1,135 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
  accelerator: cpu
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: tests/
      name: all_ecos_random
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 5
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: tests/
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 4
    num_workers: 4
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    dataset_bands:
      - [0, 11]
    output_bands:
      - [1, 3]
      - [4, 6]
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: tests/
    train_label_data_root: tests/
    val_data_root: tests/
    val_label_data_root: tests/
    test_data_root: tests/
    test_label_data_root: tests/
    img_grep: "segmentation*input*.tif"
    label_grep: "segmentation*label*.tif"
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    no_label_replace: -1
    no_data_replace: 0
    num_classes: 2
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: UperNetDecoder
      pretrained: true
      backbone: prithvi_swin_B
      backbone_pretrained_cfg_overlay:
        file: tests/prithvi_swin_B.pt
      backbone_drop_path_rate: 0.3
      # backbone_window_size: 8
      decoder_channels: 256
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_frames: 1
      num_classes: 2
      head_dropout: 0.5708022831486758
    loss: ce
    #aux_heads:
    #  - name: aux_head
    #    decoder: IdentityDecoder
    #    decoder_args:
    #      decoder_out_index: 2
    #      head_dropout: 0,5
    #      head_channel_list:
    #        - 64
    #      head_final_act: torch.nn.ReLU
    #aux_loss:
    #  aux_head: 0.4
    ignore_index: -1
    freeze_backbone: true
    freeze_decoder: false
    model_factory: PrithviModelFactory

    # uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00013524680528283027
    weight_decay: 0.047782217873995426
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
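Note that output_bands is given here as integer intervals rather than band names, the same interval convention handled by the new _estimate_in_chans helper. A quick, purely illustrative sanity check that the intervals above line up with in_channels: 6 and the six named bands in model_args:

# Intervals are inclusive on both ends, so [1, 3] and [4, 6] contribute 3 bands each.
output_bands = [[1, 3], [4, 6]]
n_channels = sum(end - start + 1 for start, end in output_bands)
assert n_channels == 6   # matches in_channels and the six listed bands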

@@ -20,7 +20,7 @@ trainer:
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 5
  max_epochs: 3
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
12 changes: 12 additions & 0 deletions tests/test_finetune.py
@@ -52,6 +52,18 @@ def test_finetune_metrics_from_file(model_name):
    _ = build_lightning_cli(command_list)

"""
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_segmentation(model_name):

    model_instance = timm.create_model(model_name)

    state_dict = model_instance.state_dict()

    torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))

    # Running the terratorch CLI
    command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_segmentation.yaml"]

@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
def test_finetune_bands_str(model_name):
    # Running the terratorch CLI