merging
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
Joao-L-S-Almeida committed Sep 11, 2024
2 parents b10d892 + 5e3d2c5 commit 3232145
Showing 20 changed files with 347 additions and 172 deletions.
37 changes: 23 additions & 14 deletions examples/confs/sen1floods11_vit.yaml
@@ -25,10 +25,17 @@ data:
num_workers: 8
constant_scale: 0.0001
dataset_bands:
- RED
- GREEN
- COASTAL_AEROSOL
- BLUE
- GREEN
- RED
- RED_EDGE_1
- RED_EDGE_2
- RED_EDGE_3
- NIR_BROAD
- NIR_NARROW
- WATER_VAPOR
- CIRRUS
- SWIR_1
- SWIR_2
output_bands:
@@ -57,19 +64,19 @@ data:
no_label_replace: -1
no_data_replace: 0
means:
- 0.107582
- 0.13471393
- 0.12520133
- 0.3236181
- 0.2341743
- 0.15878009
- 0.1412956
- 0.13795798
- 0.12353792
- 0.30902815
- 0.2044958
- 0.11912015
stds:
- 0.07145836
- 0.06783548
- 0.07323416
- 0.09489725
- 0.07938496
- 0.07089546
- 0.07406382
- 0.07370365
- 0.08692279
- 0.11798815
- 0.09772074
- 0.07659938
num_classes: 2

model:
@@ -123,3 +130,5 @@ lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss


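The band lists matter because `output_bands` (in the collapsed part of this hunk) selects a subset of `dataset_bands` by name, and the six `means`/`stds` entries line up with the selected bands, one per output channel. A minimal sketch of that selection, under the assumption that TerraTorch resolves output bands to channel indices by name; the band names are from the config above, while the `output_bands` subset shown here is illustrative:

```python
# Sketch: resolve output_bands to channel indices in a (C, H, W) raster.
# Assumes name-based lookup; illustrative, not TerraTorch's actual code.
dataset_bands = [
    "COASTAL_AEROSOL", "BLUE", "GREEN", "RED", "RED_EDGE_1", "RED_EDGE_2",
    "RED_EDGE_3", "NIR_BROAD", "NIR_NARROW", "WATER_VAPOR", "CIRRUS",
    "SWIR_1", "SWIR_2",
]
output_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]  # hypothetical subset

band_indices = [dataset_bands.index(band) for band in output_bands]
# image has shape (channels, height, width); keep only the requested bands:
# subset = image[band_indices, :, :]
print(band_indices)  # [1, 2, 3, 8, 11, 12]
```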
53 changes: 30 additions & 23 deletions examples/confs/sen1floods11_vit_local_ckpt.yaml
@@ -24,10 +24,17 @@ data:
num_workers: 8
constant_scale: 0.0001
dataset_bands:
- RED
- GREEN
- COASTAL_AEROSOL
- BLUE
- GREEN
- RED
- RED_EDGE_1
- RED_EDGE_2
- RED_EDGE_3
- NIR_BROAD
- NIR_NARROW
- WATER_VAPOR
- CIRRUS
- SWIR_1
- SWIR_2
output_bands:
@@ -41,34 +48,34 @@ data:
- 2
- 1
- 0
train_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/S2Hand/
train_label_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/LabelHand
val_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/S2Hand/
val_label_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/LabelHand
test_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/S2Hand/
test_label_data_root: <senfloods_root>/senfloods/v1.1/data/flood_events/HandLabeled/LabelHand
train_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
train_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
val_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
val_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
test_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
test_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
# these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
train_split: <senfloods_root>/senfloods/v1.1/splits/flood_handlabeled/flood_train_data.txt
test_split: <senfloods_root>/senfloods/v1.1/splits/flood_handlabeled/flood_test_data.txt
val_split: <senfloods_root>/senfloods/v1.1/splits/flood_handlabeled/flood_valid_data.txt
train_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_train_data.txt
test_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_test_data.txt
val_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_valid_data.txt
img_grep: "*_S2Hand.tif"
label_grep: "*_LabelHand.tif"
no_label_replace: -1
no_data_replace: 0
means:
- 0.107582
- 0.13471393
- 0.12520133
- 0.3236181
- 0.2341743
- 0.15878009
- 0.1412956
- 0.13795798
- 0.12353792
- 0.30902815
- 0.2044958
- 0.11912015
stds:
- 0.07145836
- 0.06783548
- 0.07323416
- 0.09489725
- 0.07938496
- 0.07089546
- 0.07406382
- 0.07370365
- 0.08692279
- 0.11798815
- 0.09772074
- 0.07659938
num_classes: 2
model:
class_path: terratorch.tasks.SemanticSegmentationTask
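The `means`/`stds` lists carry one statistic per output band. A sketch of the usual per-band standardization, taking the latter six-value lists in the hunk above as the values this commit introduces (an assumption about the flattened diff order), and assuming the familiar (x - mean) / std form rather than verified TerraTorch internals:

```python
import numpy as np

# Per-band statistics from the config above (one entry per output band).
means = np.array([0.1412956, 0.13795798, 0.12353792, 0.30902815, 0.2044958, 0.11912015])
stds = np.array([0.07406382, 0.07370365, 0.08692279, 0.11798815, 0.09772074, 0.07659938])

def normalize(image: np.ndarray) -> np.ndarray:
    """Standardize a (channels, H, W) image band-wise: (x - mean) / std."""
    return (image - means[:, None, None]) / stds[:, None, None]
```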
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -7,7 +7,7 @@ include = ["terratorch*"]

[project]
name = "terratorch"
version = "0.99.1"
version = "0.99.3"
description = "TerraTorch - A model training toolkit for geospatial tasks"
license = { "text" = "Apache License, Version 2.0" }
readme = "README.md"
@@ -152,7 +152,7 @@ exclude_lines = [
]

[tool.bumpver]
current_version = "0.99.1"
current_version = "0.99.3"
version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
commit_message = "Bump version {old_version} -> {new_version}"
commit = true
54 changes: 37 additions & 17 deletions terratorch/cli_tools.py
@@ -2,6 +2,7 @@

import logging # noqa: I001
import os
import shutil
import warnings
from datetime import timedelta
from pathlib import Path
@@ -176,6 +177,13 @@ def __init__(
super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir)
set_dumper("deploy_config", clean_config_for_deployment_and_dump)

# Preparing information to save config file to log dir
config_dict = config.as_dict()
self.config_path_original = str(config_dict["config"][0])
_, self.config_file_original = os.path.split(self.config_path_original)

self.deploy_config_file = config_dict["deploy_config_file"]

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
if self.already_saved:
return
@@ -207,29 +215,36 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
# the `log_dir` needs to be created as we rely on the logger to do it usually
# but it hasn't logged anything at this point
fs.makedirs(log_dir, exist_ok=True)

if self.deploy_config_file:
self.parser.save(
self.config, config_path, skip_none=True, overwrite=self.overwrite, multifile=self.multifile
)

if trainer.is_global_zero:
if self.deploy_config_file:
# also save the config that will be deployed
config_name, config_ext = os.path.splitext(self.config_filename)
config_name += "_deploy"
config_name += config_ext
config_path = os.path.join(log_dir, config_name)
self.parser.save(
self.config, config_path, skip_none=True, overwrite=self.overwrite, multifile=self.multifile
self.config,
config_path,
format="deploy_config",
skip_none=True,
overwrite=self.overwrite,
multifile=self.multifile,
)
self.already_saved = True

if trainer.is_global_zero:
# also save the config that will be deployed
config_name, config_ext = os.path.splitext(self.config_filename)
config_name += "_deploy"
config_name += config_ext
config_path = os.path.join(log_dir, config_name)
self.parser.save(
self.config,
config_path,
format="deploy_config",
skip_none=True,
overwrite=self.overwrite,
multifile=self.multifile,
)
self.already_saved = True
config_path_dir, config_path_file = os.path.split(config_path)
self.config_path_new = os.path.join(config_path_dir, self.config_file_original)

# broadcast so that all ranks are in sync on future calls to .setup()
self.already_saved = trainer.strategy.broadcast(self.already_saved)

# Copying config file to log dir
shutil.copyfile(self.config_path_original, self.config_path_new)

class StateDictAwareModelCheckpoint(ModelCheckpoint):
# necessary as we wish to have one model checkpoint with only state dict and one with standard lightning checkpoints
@@ -284,6 +299,7 @@ class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_argument("--predict_output_dir", default=None)
parser.add_argument("--out_dtype", default="int16")
parser.add_argument("--deploy_config_file", type=bool, default=True)

# parser.set_defaults({"trainer.enable_checkpointing": False})

@@ -314,6 +330,10 @@ def instantiate_classes(self) -> None:
if hasattr(config, "out_dtype"):
self.trainer.out_dtype = config.out_dtype

if hasattr(config, "deploy_config_file"):
self.trainer.deploy_config = config.deploy_config_file


def build_lightning_cli(
args: ArgsType = None,
run=True, # noqa: FBT002
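The changes above add a `--deploy_config_file` flag (default `True`) and, on rank zero, copy the user's original config file into the log directory alongside the generated `*_deploy` config. The copy step in isolation, as a sketch mirroring the logic in the diff (the standalone function and its name are mine, not part of the module):

```python
import os
import shutil

def copy_original_config(config_path_original: str, log_dir: str) -> str:
    """Copy the user's original config file into the run's log directory."""
    _, config_file_original = os.path.split(config_path_original)
    config_path_new = os.path.join(log_dir, config_file_original)
    shutil.copyfile(config_path_original, config_path_new)
    return config_path_new
```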
32 changes: 23 additions & 9 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
@@ -7,9 +7,10 @@
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any
import numpy as np

import albumentations as A
import kornia.augmentation as K
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
@@ -19,6 +20,7 @@
from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands
from terratorch.io.file import load_from_file_or_attribute


def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list
@@ -92,8 +94,9 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -133,9 +136,14 @@ def __init__(
allow_substring_split_file (bool, optional): Whether the split files contain substrings
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
Naming must match that of dataset_bands. Defaults to None.
predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
with this value at predict time.
Defaults to None, which does not overwrite.
predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
with this value at predict time. Defaults to None, which does not overwrite.
constant_scale (float, optional): _description_. Defaults to 1.
rgb_indices (list[int] | None, optional): _description_. Defaults to None.
train_transform (Albumentations.Compose | None): Albumentations transform
@@ -185,6 +193,7 @@ def __init__(

self.dataset_bands = dataset_bands
self.predict_dataset_bands = predict_dataset_bands if predict_dataset_bands else dataset_bands
self.predict_output_bands = predict_output_bands if predict_output_bands else output_bands
self.output_bands = output_bands
self.rgb_indices = rgb_indices
self.expand_temporal_dimension = expand_temporal_dimension
@@ -334,9 +343,9 @@ def __init__(
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -375,9 +384,14 @@ def __init__(
allow_substring_split_file (bool, optional): Whether the split files contain substrings
that must be present in file names to be included (as in mmsegmentation), or exact
matches (e.g. eurosat). Defaults to True.
dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
predict_dataset_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
output_bands (list[HLSBands | int] | None, optional): _description_. Defaults to None.
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset. Defaults to None.
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
Naming must match that of dataset_bands. Defaults to None.
predict_dataset_bands (list[HLSBands | int] | None): Overwrites dataset_bands
with this value at predict time.
Defaults to None, which does not overwrite.
predict_output_bands (list[HLSBands | int] | None): Overwrites output_bands
with this value at predict time. Defaults to None, which does not overwrite.
constant_scale (float, optional): _description_. Defaults to 1.
rgb_indices (list[int] | None, optional): _description_. Defaults to None.
train_transform (Albumentations.Compose | None): Albumentations transform
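The docstrings above describe the new `predict_output_bands` parameter the same way as `predict_dataset_bands`: it overrides its train-time counterpart at predict time and falls back to it when `None`. A standalone sketch of that fallback semantics (the helper function is illustrative, not part of the module):

```python
def resolve_predict_bands(dataset_bands, output_bands,
                          predict_dataset_bands=None, predict_output_bands=None):
    """Return (dataset_bands, output_bands) as used at predict time.

    None means "do not override": fall back to the train-time lists,
    mirroring the constructor logic in the diff above.
    """
    return (
        predict_dataset_bands if predict_dataset_bands else dataset_bands,
        predict_output_bands if predict_output_bands else output_bands,
    )

# Example: predict-time imagery stores its channels in a different order.
train_bands = ["BLUE", "GREEN", "RED", "NIR_NARROW"]
predict_bands = ["RED", "GREEN", "BLUE", "NIR_NARROW"]
resolved = resolve_predict_bands(train_bands, train_bands,
                                 predict_dataset_bands=predict_bands)
assert resolved == (predict_bands, train_bands)
```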
1 change: 1 addition & 0 deletions terratorch/datasets/generic_pixel_wise_dataset.py
@@ -140,6 +140,7 @@ def __init__(
# self.transform = transform if transform else ToTensorV2()

import warnings

import rasterio
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

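The one-line change above only regroups imports, but the code it touches suppresses rasterio's `NotGeoreferencedWarning`. For context, a sketch of why that filter exists (the file name here is hypothetical):

```python
import warnings

import rasterio

# Opening a TIFF without CRS/transform metadata triggers this warning on
# every read; the dataset silences it globally instead.
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

with rasterio.open("chip_without_georeferencing.tif") as src:  # hypothetical file
    data = src.read()  # (bands, height, width) array, no warning emitted
```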
29 changes: 23 additions & 6 deletions terratorch/models/backbones/prithvi_select_patch_embed_weights.py
@@ -1,13 +1,21 @@
# Copyright contributors to the Terratorch project

import logging
import warnings

import torch
from torch import nn

from terratorch.datasets import HLSBands


def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor) -> bool:
# check all dimensions are the same except for channel dimension
if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape):
return False
model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1]
checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
return model_shape == checkpoint_shape

def prithvi_select_patch_embed_weights(
state_dict: dict, model: nn.Module, pretrained_bands: list[HLSBands | int], model_bands: list[HLSBands | int]
) -> dict:
@@ -44,10 +52,19 @@ def prithvi_select_patch_embed_weights(
patch_embed_weight = state_dict[patch_embed_proj_weight_key]

temp_weight = model.state_dict()[patch_embed_proj_weight_key].clone()
torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
for index, band in enumerate(model_bands):
if band in pretrained_bands:
temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]

# only do this if the patch size and tubelet size match. If not, start with random weights
if patch_embed_weights_are_compatible(temp_weight, patch_embed_weight):
torch.nn.init.xavier_uniform_(temp_weight.view([temp_weight.shape[0], -1]))
for index, band in enumerate(model_bands):
if band in pretrained_bands:
temp_weight[:, index] = patch_embed_weight[:, pretrained_bands.index(band)]
else:
warnings.warn(
f"Incompatible shapes between patch embedding of model {temp_weight.shape} and of checkpoint {patch_embed_weight.shape}",
category=UserWarning,
stacklevel=1,
)

state_dict[patch_embed_proj_weight_key] = temp_weight
return state_dict
return state_dict
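The new `patch_embed_weights_are_compatible` helper compares every dimension of the two patch-embedding weight tensors except the channel dimension (dim 1), so pretrained weights are copied band-by-band only when embedding width and patch/tubelet size match; otherwise the model keeps its randomly initialized weights and the warning above fires. A worked example of the same check (the shapes are illustrative):

```python
import torch

def patch_embed_weights_are_compatible(model_pe: torch.Tensor, ckpt_pe: torch.Tensor) -> bool:
    # Same check as the diff: ignore dim 1 (channels), compare everything else.
    if len(model_pe.shape) != len(ckpt_pe.shape):
        return False
    model_shape = [s for i, s in enumerate(model_pe.shape) if i != 1]
    ckpt_shape = [s for i, s in enumerate(ckpt_pe.shape) if i != 1]
    return model_shape == ckpt_shape

model_weight = torch.zeros(768, 6, 16, 16)   # embed_dim, model bands, 16x16 patches
ckpt_weight = torch.zeros(768, 13, 16, 16)   # matches everywhere except bands
assert patch_embed_weights_are_compatible(model_weight, ckpt_weight)

# A checkpoint trained with a different patch size is rejected:
assert not patch_embed_weights_are_compatible(model_weight, torch.zeros(768, 13, 14, 14))
```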
