Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom modules #217

Merged
merged 28 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
19e5a4e
Adapting segmentation tasks to support already instantiated models
Joao-L-S-Almeida Nov 1, 2024
30ee4cc
model_factory wasn't being initialized
Joao-L-S-Almeida Nov 1, 2024
d55c1d2
Minor adjusts
Joao-L-S-Almeida Nov 1, 2024
a4aa68b
Example of combined modules
Joao-L-S-Almeida Nov 1, 2024
558e23b
default value for kwarg
Joao-L-S-Almeida Nov 1, 2024
e3c8540
merging with main
Joao-L-S-Almeida Nov 14, 2024
9daa9cb
Exception for the case in which no model info is passed
Joao-L-S-Almeida Nov 14, 2024
70e5b58
merging with main
Joao-L-S-Almeida Nov 18, 2024
e04ca01
Proposal for a scalar regression task
Joao-L-S-Almeida Nov 19, 2024
c776e43
Minor adjust
Joao-L-S-Almeida Nov 19, 2024
cf01875
Enforcing some kind of padding when necessary
Joao-L-S-Almeida Nov 19, 2024
4b1dbf4
merging
Joao-L-S-Almeida Nov 21, 2024
9ce51f2
updates
Joao-L-S-Almeida Nov 28, 2024
41f0e4a
padding for PatchEmbed
Joao-L-S-Almeida Nov 28, 2024
64728d1
solving conflict
Joao-L-S-Almeida Nov 28, 2024
44abfc0
Updating WxC script
Joao-L-S-Almeida Nov 29, 2024
7652bb4
Regression tasks must support instantiated models
Joao-L-S-Almeida Nov 29, 2024
5f4894f
repeated methods
Joao-L-S-Almeida Nov 29, 2024
bb97cf6
unnecessary method
Joao-L-S-Almeida Nov 29, 2024
2e07271
unnecessary
Joao-L-S-Almeida Nov 29, 2024
21cb3cd
regression tasks must support extra parameters
Joao-L-S-Almeida Nov 30, 2024
6ecd3d5
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Nov 30, 2024
07fe3e3
Merge branch 'main' into custom_modules
Joao-L-S-Almeida Dec 2, 2024
5673488
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Dec 2, 2024
0c6a38d
model as property
Joao-L-S-Almeida Dec 2, 2024
5e06707
Merge branch 'custom_modules' of github.com:IBM/terratorch into custo…
Joao-L-S-Almeida Dec 2, 2024
b4c602d
workaround to supress tons of warnings
Joao-L-S-Almeida Dec 3, 2024
55acdb9
Merge branch 'main' into custom_modules
romeokienzler Dec 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions examples/scripts/test_combined_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from terratorch.models import EncoderDecoderFactory
from terratorch.datasets import HLSBands
import torch

import os
import subprocess

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.models.model import AuxiliaryHead
from terratorch.tasks import SemanticSegmentationTask
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

import shutil
import matplotlib.pyplot as plt
import rioxarray as rio


class CustomModel(torch.nn.Module):
def __init__(self, model_1:torch.nn.Module=None, model_2:torch.nn.Module=None):

super().__init__()
self.model_1 = model_1
self.model_2 = model_2

def forward(self, x:torch.Tensor):

output_1 = self.model_1(x)
output_2 = self.model_2(x)
mask = (output_1.output + output_2.output)/2

return ModelOutput(output=mask)

def freeze_encoder(self):

self.model_1.freeze_encoder()
self.model_2.freeze_encoder()

def freeze_decoder(self):

self.model_1.freeze_decoder()
self.model_2.freeze_decoder()

model_factory = EncoderDecoderFactory()

batch_size = 1
num_workers = 19

subprocess.run("wget https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-Burn-scars-demo/resolve/main/subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif", shell=True)

input_file_name = "subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4_merged.tif"
label_file_name = "subsetted_512x512_HLS.S30.T10TGS.2018285.v1.4.mask.tif"

# organize the data directory
if not os.path.isdir("burn_scar_segmentation_toy"):
os.mkdir("burn_scar_segmentation_toy")

for data_dir in ["train_images", "test_images", "val_images"]:
os.mkdir(os.path.join("burn_scar_segmentation_toy", data_dir))
shutil.copy(input_file_name, os.path.join("burn_scar_segmentation_toy", data_dir, input_file_name))

for label_dir in ["train_labels", "test_labels", "val_labels"]:
os.mkdir(os.path.join("burn_scar_segmentation_toy", label_dir))
shutil.copy(label_file_name, os.path.join("burn_scar_segmentation_toy", label_dir, label_file_name))

train_val_test = [
"burn_scar_segmentation_toy/train_images",
"burn_scar_segmentation_toy/val_images",
"burn_scar_segmentation_toy/test_images",
]

train_val_test_labels = {
"train_label_data_root": "burn_scar_segmentation_toy/train_labels",
"val_label_data_root": "burn_scar_segmentation_toy/val_labels",
"test_label_data_root": "burn_scar_segmentation_toy/test_labels",
}

# from https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/burn_scars.py
means=[
0.033349706741586264,
0.05701185520536176,
0.05889748132001316,
0.2323245113436119,
0.1972854853760658,
0.11944914225186566,
]
stds=[
0.02269135568823774,
0.026807560223070237,
0.04004109844362779,
0.07791732423672691,
0.08708738838140137,
0.07241979477437814,
]
datamodule = GenericNonGeoSegmentationDataModule(
batch_size,
num_workers,
*train_val_test,
"*_merged.tif", # img grep
"*.mask.tif", # label grep
means,
stds,
2, # num classes
**train_val_test_labels,

# if transforms are defined with Albumentations, you can pass them here
# train_transform=train_transform,
# val_transform=val_transform,
# test_transform=test_transform,

# edit the below for your usecase
dataset_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
output_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
no_data_replace=0,
no_label_replace=-1,
)
# we want to access some properties of the train dataset later on, so lets call setup here
# if not, we would not need to
datamodule.setup("fit")

model_1 = model_factory.build_model(task="segmentation",
backbone="prithvi_vit_tiny",
decoder="IdentityDecoder",
backbone_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
num_classes=2,
backbone_pretrained=False, #True,
backbone_num_frames=1,
head_dropout=0.2
)

model_2 = model_factory.build_model(task="segmentation",
backbone="prithvi_vit_tiny",
decoder="IdentityDecoder",
backbone_bands=[
HLSBands.BLUE,
HLSBands.GREEN,
HLSBands.RED,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
num_classes=2,
backbone_pretrained=False, #True,
backbone_num_frames=1,
head_dropout=0.2
)

model = CustomModel(model_1=model_1, model_2=model_2)
model.freeze_encoder()

epochs = 1 # 1 epoch for demo
lr = 1e-3

model_args = {
"num_classes": 2,
"backbone_bands": [
HLSBands.RED,
HLSBands.GREEN,
HLSBands.BLUE,
HLSBands.NIR_NARROW,
HLSBands.SWIR_1,
HLSBands.SWIR_2,
],
"backbone_pretrained": False, #True,
"backbone_num_frames":1, # this is the default
"decoder_channels":128,
"head_dropout":0.2,
"necks": [
{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"}
]
}

task = SemanticSegmentationTask(
model_args,
None,
model=model,
loss="ce",
#aux_loss={"fcn_aux_head": 0.4},
lr=lr,
ignore_index=-1,
optimizer="AdamW",
optimizer_hparams={"weight_decay": 0.05},
)

accelerator = "gpu"
experiment = "tutorial"
if not os.path.isdir("tutorial_experiments"):
os.mkdir("tutorial_experiments")
default_root_dir = os.path.join("tutorial_experiments", experiment)
checkpoint_callback = ModelCheckpoint(monitor=task.monitor, save_top_k=1, save_last=True)
early_stopping_callback = EarlyStopping(monitor=task.monitor, min_delta=0.00, patience=20)
logger = TensorBoardLogger(save_dir=default_root_dir, name=experiment)

trainer = Trainer(
# precision="16-mixed",
accelerator=accelerator,
callbacks=[
RichProgressBar(),
checkpoint_callback,
LearningRateMonitor(logging_interval="epoch"),
],
logger=logger,
max_epochs=epochs, # train only one epoch for demo
default_root_dir=default_root_dir,
log_every_n_steps=1,
check_val_every_n_epoch=200

)

trainer.fit(model=task, datamodule=datamodule)
37 changes: 32 additions & 5 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import lightning
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -40,6 +40,7 @@ def __init__(
self,
model_args: dict,
model_factory: str,
model: Optional[torch.nn.Module]=None,
loss: str = "ce",
aux_heads: list[AuxiliaryHead] | None = None,
aux_loss: dict[str, float] | None = None,
Expand Down Expand Up @@ -98,22 +99,47 @@ def __init__(
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
self.aux_heads = aux_heads
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
self._model_module = None

if model_factory:
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
self.model_builder = self._build
else:
self.model_builder = self._bypass_build

self._model_module = None

super().__init__()

self._model_module = model
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler = LossHandler(self.test_metrics.prefix)
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"
self.plot_on_val = int(plot_on_val)

# overwrite early stopping
@property
def model_module(self):
return self._model_module

# overwrite early stopping
def configure_callbacks(self) -> list[Callback]:
return []

def configure_models(self) -> None:
self.model: Model = self.model_factory.build_model(
def _bypass_build(self):
return self.model_module

def _build(self):

return self.model_factory.build_model(
"segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"]
)

def configure_models(self) -> None:
self.model: Model = self.model_builder()

if self.hparams["freeze_backbone"]:
self.model.freeze_encoder()
if self.hparams["freeze_decoder"]:
Expand Down Expand Up @@ -259,6 +285,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
"""
x = batch["image"]
y = batch["mask"]

model_output: ModelOutput = self(x)
loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
Expand Down
Loading