Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom modules #217

Merged
merged 28 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
19e5a4e
Adapting segmentation tasks to support already instantiated models
Joao-L-S-Almeida Nov 1, 2024
30ee4cc
model_factory wasn't being initialized
Joao-L-S-Almeida Nov 1, 2024
d55c1d2
Minor adjusts
Joao-L-S-Almeida Nov 1, 2024
a4aa68b
Example of combined modules
Joao-L-S-Almeida Nov 1, 2024
558e23b
default value for kwarg
Joao-L-S-Almeida Nov 1, 2024
e3c8540
merging with main
Joao-L-S-Almeida Nov 14, 2024
9daa9cb
Exception for the case in which no model info is passed
Joao-L-S-Almeida Nov 14, 2024
70e5b58
merging with main
Joao-L-S-Almeida Nov 18, 2024
e04ca01
Proposal for a scalar regression task
Joao-L-S-Almeida Nov 19, 2024
c776e43
Minor adjust
Joao-L-S-Almeida Nov 19, 2024
cf01875
Enforcing some kind of padding when necessary
Joao-L-S-Almeida Nov 19, 2024
4b1dbf4
merging
Joao-L-S-Almeida Nov 21, 2024
9ce51f2
updates
Joao-L-S-Almeida Nov 28, 2024
41f0e4a
padding for PatchEmbed
Joao-L-S-Almeida Nov 28, 2024
64728d1
solving conflict
Joao-L-S-Almeida Nov 28, 2024
44abfc0
Updating WxC script
Joao-L-S-Almeida Nov 29, 2024
7652bb4
Regression tasks must support instantiated models
Joao-L-S-Almeida Nov 29, 2024
5f4894f
repeated methods
Joao-L-S-Almeida Nov 29, 2024
bb97cf6
unnecessary method
Joao-L-S-Almeida Nov 29, 2024
2e07271
unnecessary
Joao-L-S-Almeida Nov 29, 2024
21cb3cd
regression tasks must support extra parameters
Joao-L-S-Almeida Nov 30, 2024
6ecd3d5
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Nov 30, 2024
07fe3e3
Merge branch 'main' into custom_modules
Joao-L-S-Almeida Dec 2, 2024
5673488
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Dec 2, 2024
0c6a38d
model as property
Joao-L-S-Almeida Dec 2, 2024
5e06707
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Dec 2, 2024
b4c602d
workaround to supress tons of warnings
Joao-L-S-Almeida Dec 3, 2024
55acdb9
Merge branch 'main' into custom_modules
romeokienzler Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ venv/*
examples/notebooks/config.yaml
examples/notebooks/wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc
tests/all_ecos_random/*
examples/**/*tif*
**/climatology/*
**/lightning_logs/*
**/merra-2/*
Expand Down
18 changes: 11 additions & 7 deletions examples/scripts/WxCTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@
#


config = get_config('../confs/granite-wxc-merra2-downscale-config.yaml')
config.download_path = './'
config = get_config('../confs/granite-wxc-merra2-downscale-large-config.yaml')
download_path = os.getcwd()
config.download_path = download_path

config.data.data_path_surface = os.path.join(config.download_path,'merra-2')
config.data.data_path_vertical = os.path.join(config.download_path, 'merra-2')
config.data.climatology_path_surface = os.path.join(config.download_path,'climatology')
config.data.climatology_path_vertical = os.path.join(config.download_path,'climatology')
config.data.data_path_surface = os.path.join(download_path,'merra-2')
config.data.data_path_vertical = os.path.join(download_path, 'merra-2')
config.data.climatology_path_surface = os.path.join(download_path,'climatology')
config.data.climatology_path_vertical = os.path.join(download_path,'climatology')

extra_kwargs = config.model.init_args["extra_kwargs"]
model_args = config.model.init_args["model_args"]

config.model.input_scalers_surface_path = os.path.join(config.download_path,'climatology/musigma_surface.nc')
config.model.input_scalers_vertical_path = os.path.join(config.download_path,'climatology/musigma_vertical.nc')
Expand Down Expand Up @@ -130,7 +134,7 @@

print("This is our config:")

task = WxCDownscalingTask(model_args = {}, model_factory = 'WxCModelFactory', model_config=config, optimizer='AdamW', optimizer_hparams={'weight_decay': 0.05})
task = WxCDownscalingTask(model_args = model_args, model_factory = 'WxCModelFactory',extra_kwargs=extra_kwargs, model_config=config, optimizer='AdamW', optimizer_hparams={'weight_decay': 0.05})


#
Expand Down
235 changes: 235 additions & 0 deletions examples/scripts/test_combined_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from terratorch.models import EncoderDecoderFactory
from terratorch.datasets import HLSBands
import torch

import os
import subprocess

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.models.model import AuxiliaryHead
from terratorch.tasks import SemanticSegmentationTask
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

import shutil
import matplotlib.pyplot as plt
import rioxarray as rio


class CustomModel(torch.nn.Module):
    """Two-member ensemble that averages the ``output`` of the wrapped models.

    Both members receive the same input tensor. Each is expected to return an
    object exposing an ``output`` attribute and to provide ``freeze_encoder``
    and ``freeze_decoder`` methods, because those are simply forwarded.
    """

    def __init__(self, model_1: torch.nn.Module = None, model_2: torch.nn.Module = None):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2

    def forward(self, x: torch.Tensor):
        """Run both members on ``x`` and return the mean of their outputs."""
        first = self.model_1(x)
        second = self.model_2(x)
        return ModelOutput(output=(first.output + second.output) / 2)

    def freeze_encoder(self):
        """Forward ``freeze_encoder`` to the first member, then the second."""
        for member in (self.model_1, self.model_2):
            member.freeze_encoder()

    def freeze_decoder(self):
        """Forward ``freeze_decoder`` to the first member, then the second."""
        for member in (self.model_1, self.model_2):
            member.freeze_decoder()

model_factory = EncoderDecoderFactory()

batch_size = 1
num_workers = 19

input_file_name = "subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif"
label_file_name = "subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4.mask.tif"

# Both the image and its mask are copied into the toy dataset below, so both
# have to be present locally. Files that already exist are not fetched again.
# NOTE(review): the mask is assumed to be hosted next to the image in the same
# Hugging Face space - confirm the URL.
data_url_root = "https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-Burn-scars-demo/resolve/main"
for file_name in (input_file_name, label_file_name):
    if not os.path.isfile(file_name):
        # Argument list instead of a shell string; check=True surfaces a failed
        # download here rather than as a confusing copy error further down.
        subprocess.run(["wget", f"{data_url_root}/{file_name}"], check=True)

# organize the data directory
# makedirs(exist_ok=True) keeps the script re-runnable: the same single tile is
# (re)copied into every split of the toy dataset.
for data_dir in ["train_images", "test_images", "val_images"]:
    os.makedirs(os.path.join("burn_scar_segmentation_toy", data_dir), exist_ok=True)
    shutil.copy(input_file_name, os.path.join("burn_scar_segmentation_toy", data_dir, input_file_name))

for label_dir in ["train_labels", "test_labels", "val_labels"]:
    os.makedirs(os.path.join("burn_scar_segmentation_toy", label_dir), exist_ok=True)
    shutil.copy(label_file_name, os.path.join("burn_scar_segmentation_toy", label_dir, label_file_name))

# Image roots, ordered train / val / test (unpacked positionally into the datamodule).
train_val_test = [
    f"burn_scar_segmentation_toy/{split}_images" for split in ("train", "val", "test")
]

# Label roots, keyed by the keyword names they are later passed under.
train_val_test_labels = {
    f"{split}_label_data_root": f"burn_scar_segmentation_toy/{split}_labels"
    for split in ("train", "val", "test")
}

# Per-band normalisation statistics, taken from
# https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/burn_scars.py
means = [
    0.033349706741586264,
    0.05701185520536176,
    0.05889748132001316,
    0.2323245113436119,
    0.1972854853760658,
    0.11944914225186566,
]
stds = [
    0.02269135568823774,
    0.026807560223070237,
    0.04004109844362779,
    0.07791732423672691,
    0.08708738838140137,
    0.07241979477437814,
]
# dataset_bands and output_bands list the same six HLS bands; edit this for your use case.
_hls_bands = [
    HLSBands.BLUE,
    HLSBands.GREEN,
    HLSBands.RED,
    HLSBands.NIR_NARROW,
    HLSBands.SWIR_1,
    HLSBands.SWIR_2,
]

datamodule = GenericNonGeoSegmentationDataModule(
    batch_size,
    num_workers,
    *train_val_test,  # train / val / test image roots
    "*_merged.tif",  # img grep
    "*.mask.tif",  # label grep
    means,
    stds,
    2,  # num classes
    **train_val_test_labels,
    # Transforms defined with Albumentations can be supplied here as
    # train_transform=..., val_transform=..., test_transform=...
    dataset_bands=list(_hls_bands),
    output_bands=list(_hls_bands),
    no_data_replace=0,
    no_label_replace=-1,
)

# setup() is called explicitly because properties of the train dataset are
# accessed later on; it would not be needed otherwise.
datamodule.setup("fit")

def _build_tiny_prithvi_segmenter():
    """Build one ensemble member: an untrained tiny Prithvi ViT segmentation model."""
    return model_factory.build_model(
        task="segmentation",
        backbone="prithvi_vit_tiny",
        decoder="IdentityDecoder",
        backbone_bands=[
            HLSBands.BLUE,
            HLSBands.GREEN,
            HLSBands.RED,
            HLSBands.NIR_NARROW,
            HLSBands.SWIR_1,
            HLSBands.SWIR_2,
        ],
        num_classes=2,
        backbone_pretrained=False,  # True,
        backbone_num_frames=1,
        head_dropout=0.2,
    )


# Two independently built members with identical configuration.
model_1 = _build_tiny_prithvi_segmenter()
model_2 = _build_tiny_prithvi_segmenter()

# Combine them and freeze the encoder of each member.
model = CustomModel(model_1=model_1, model_2=model_2)
model.freeze_encoder()

epochs = 1  # 1 epoch for demo
lr = 1e-3

# NOTE(review): backbone_bands is ordered RED, GREEN, BLUE here, whereas the
# bands used to build the backbones and the datamodule are ordered
# BLUE, GREEN, RED - confirm whether the order matters when an already
# instantiated model is handed to the task.
model_args = dict(
    num_classes=2,
    backbone_bands=[
        HLSBands.RED,
        HLSBands.GREEN,
        HLSBands.BLUE,
        HLSBands.NIR_NARROW,
        HLSBands.SWIR_1,
        HLSBands.SWIR_2,
    ],
    backbone_pretrained=False,  # True,
    backbone_num_frames=1,  # this is the default
    decoder_channels=128,
    head_dropout=0.2,
    necks=[
        {"name": "SelectIndices", "indices": [-1]},
        {"name": "ReshapeTokensToImage"},
    ],
)

# The ready-made ensemble is passed through ``model``; the second positional
# argument is left as None (presumably the model factory - confirm against
# SemanticSegmentationTask).
task = SemanticSegmentationTask(
    model_args,
    None,
    model=model,
    loss="ce",
    # aux_loss={"fcn_aux_head": 0.4},
    lr=lr,
    ignore_index=-1,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
)

accelerator = "gpu"
experiment = "tutorial"

# Make sure the parent directory for this experiment's outputs exists.
os.makedirs("tutorial_experiments", exist_ok=True)
default_root_dir = os.path.join("tutorial_experiments", experiment)

logger = TensorBoardLogger(save_dir=default_root_dir, name=experiment)
checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
# Created for reference only: it is not part of the callbacks handed to the Trainer below.
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)

# NOTE(review): with max_epochs=1 and check_val_every_n_epoch=200 the
# validation loop presumably never runs during this demo - confirm intended.
trainer = Trainer(
    # precision="16-mixed",
    accelerator=accelerator,
    logger=logger,
    default_root_dir=default_root_dir,
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    max_epochs=epochs,  # train only one epoch for demo
    log_every_n_steps=1,
    check_val_every_n_epoch=200,
)

trainer.fit(model=task, datamodule=datamodule)
4 changes: 3 additions & 1 deletion terratorch/models/backbones/prithvi_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,15 @@ def __init__(

self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.warn_counter = 0

def forward(self, x):
B, C, T, H, W = x.shape

if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
if (T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1) and self.warn_counter == 0 :
logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
self.warn_counter += 1

x = self.proj(x)
if self.flatten:
Expand Down
34 changes: 34 additions & 0 deletions terratorch/models/backbones/vit_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,29 @@ def __init__(
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def pad_images(self, imgs: Tensor, patch_size: int = None, padding: str = 'constant') -> Tensor:
    """Trim the bottom/right border so H and W become multiples of the patch size.

    Despite the name, the computed pad amounts are never positive, so
    ``nn.functional.pad`` removes rows/columns instead of adding them. Inputs
    that are already divisible keep their size.

    Args:
        imgs: Tensor whose last two dimensions are height and width.
        patch_size: Spatial patch size to align to. When ``None`` (the
            default) ``self.patch_size[0]`` is used.
        padding: ``mode`` forwarded to ``nn.functional.pad``.

    Returns:
        The tensor with trailing rows/columns removed where necessary.
    """
    # Honour an explicit patch size; previously the argument was silently ignored.
    p = self.patch_size[0] if patch_size is None else patch_size

    h, w = imgs.shape[-2:]
    # -(size % p) equals (size // p) * p - size: zero when divisible, negative otherwise.
    h_pad = -(h % p)
    w_pad = -(w % p)

    # Pad order is (left, right, top, bottom): only the right and bottom edges change.
    return nn.functional.pad(imgs, (0, w_pad, 0, h_pad), mode=padding)

def forward(self, x):
    """Project the input into patch embeddings.

    A tensor with ``B_C_H_W_SHAPE_LEN`` dimensions is reshaped to carry an
    explicit frame axis when the module is configured for a single frame.
    The result of ``self.pad_images`` must be five-dimensional.
    """
    if x.ndim == B_C_H_W_SHAPE_LEN and self.num_frames == 1:
        height_width = x.shape[-2:]
        x = x.reshape(-1, self.in_chans, self.num_frames, *height_width)

    x = self.pad_images(x)
    # The unpacking is kept for its side effect: it raises unless x has exactly five dims.
    B, C, T, H, W = x.shape  # noqa: N806

    tokens = self.proj(x)
    if self.flatten:
        # B,C,T,H,W -> B,C,L -> B,L,C
        tokens = tokens.flatten(2).transpose(1, 2)

    return self.norm(tokens)


Expand Down Expand Up @@ -249,6 +263,7 @@ def __init__(
encoder_only: bool = True, # noqa: FBT001, FBT002
coords_encoding: None | list[str] = None,
coords_scale_learn: bool = False, # noqa: ARG002, FBT001, FBT002
padding: bool | None = False,
**kwargs, # timm parameters that may be passed # noqa: ARG002
):
"""
Expand Down Expand Up @@ -288,6 +303,8 @@ def __init__(
self.feature_info = []
self.in_chans = in_chans
self.num_frames = num_frames
self.padding = padding # optional
self.pad_images = self._pad_images if self.padding else self._bypass_pad_images

self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
Expand Down Expand Up @@ -374,13 +391,29 @@ def _init_weights(self, m):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def _pad_images(self, imgs: Tensor, patch_size:int=None, padding:str='constant') -> Tensor:

p = patch_size

t, h, w = imgs.shape[-3:]
h_pad = (h // p) * p - h # Ensure padding is within bounds
w_pad = (w // p) * p - w # Ensure padding is within bounds
#if h_pad > 0 or w_pad > 0:
imgs = nn.functional.pad(imgs, (0, w_pad, 0, h_pad), mode=padding)
return imgs

def _bypass_pad_images(self, imgs: Tensor, patch_size:int=None, padding:str='constant') -> Tensor:

return imgs

def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    """Rearrange a batch of image stacks into flattened patch tokens.

    imgs: B, C, T, H, W
    x: B, L, D

    The input first goes through ``self.pad_images`` with the spatial patch
    size, then is split into tubelets of ``tubelet_size`` frames and square
    patches of that spatial size.
    """
    patch = self.patch_embed.patch_size[0]
    tubelet = self.patch_embed.tubelet_size
    aligned = self.pad_images(imgs, patch_size=patch)
    return rearrange(
        aligned,
        "b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)",
        tub=tubelet,
        p=patch,
        q=patch,
    )
Expand Down Expand Up @@ -443,6 +476,7 @@ def forward_encoder(
x = x.reshape(-1, self.in_chans, 1, *x.shape[-2:])
t, h, w = x.shape[-3:]
x = self.patch_embed(x)

pos_embed = torch.from_numpy(
get_3d_sincos_pos_embed(
self.embed_dim,
Expand Down
2 changes: 1 addition & 1 deletion terratorch/models/encoder_decoder_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from terratorch.registry import BACKBONE_REGISTRY, DECODER_REGISTRY, MODEL_FACTORY_REGISTRY

PIXEL_WISE_TASKS = ["segmentation", "regression"]
SCALAR_TASKS = ["classification"]
SCALAR_TASKS = ["classification", "scalar_regression"]
SUPPORTED_TASKS = PIXEL_WISE_TASKS + SCALAR_TASKS


Expand Down
3 changes: 3 additions & 0 deletions terratorch/models/scalar_output_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,8 @@ def _get_head(self, task: str, input_embed_dim: int, head_kwargs: dict):
msg = "num_classes must be defined for classification task"
raise Exception(msg)
return ClassificationHead(input_embed_dim, **head_kwargs)
elif task == "scalar_regression":
return ClassificationHead(input_embed_dim, **head_kwargs)

msg = "Task must be classification."
raise Exception(msg)
Loading
Loading