
Feat: Implement option to have multiple learning rates #329

Merged
merged 9 commits on Jan 17, 2025
133 changes: 133 additions & 0 deletions examples/confs/sen1floods11_vit_dual_lr.yaml
@@ -0,0 +1,133 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger: True # will use tensorboardlogger
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch

  max_epochs: 200
  check_val_every_n_epoch: 1
  log_every_n_steps: 50
  enable_checkpointing: true
  default_root_dir: <your_path_here>
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 16
    num_workers: 8
    constant_scale: 0.0001
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - RED_EDGE_1
      - RED_EDGE_2
      - RED_EDGE_3
      - NIR_BROAD
      - NIR_NARROW
      - WATER_VAPOR
      - CIRRUS
      - SWIR_1
      - SWIR_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    train_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    val_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    val_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    test_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
    test_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
    # these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
    train_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_train_data.txt
    test_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_test_data.txt
    val_split: <sen1floods11_root>/v1.1/splits/flood_handlabeled/flood_valid_data.txt
    img_grep: "*_S2Hand.tif"
    label_grep: "*_LabelHand.tif"
    no_label_replace: -1
    no_data_replace: 0
    means:
      - 0.1412956
      - 0.13795798
      - 0.12353792
      - 0.30902815
      - 0.2044958
      - 0.11912015
    stds:
      - 0.07406382
      - 0.07370365
      - 0.08692279
      - 0.11798815
      - 0.09772074
      - 0.07659938
    num_classes: 2

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: FCNDecoder
      backbone_pretrained: true
      backbone: prithvi_vit_100
      decoder_channels: 256
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      num_classes: 2
      head_dropout: 0.1
      decoder_num_convs: 4
      head_channel_list:
        - 256
      necks:
        - name: SelectIndices
          indices:
            - -1
        - name: ReshapeTokensToImage
    loss: ce
    aux_heads:
      - name: aux_head
        decoder: FCNDecoder
        decoder_args:
          decoder_channels: 256
          decoder_in_index: -1
          decoder_num_convs: 2
          head_dropout: 0.1
          # head_channel_list:
          #   - 64
    aux_loss:
      aux_head: 1.0
    ignore_index: -1
    class_weights:
      - 0.3
      - 0.7
    freeze_backbone: false
    freeze_decoder: false
    model_factory: EncoderDecoderFactory
    optimizer: AdamW
    lr: 1e-4
    lr_overrides:
      encoder: 1e-5
    scheduler: ReduceLROnPlateau
    scheduler_hparams:
      monitor: val/loss
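Note on the new `lr_overrides` block: each key is matched as a substring against the model's parameter names, so with the config above every parameter whose qualified name contains `encoder` trains at 1e-5 while all remaining parameters keep the top-level `lr` of 1e-4 (the `LearningRateMonitor` callback then logs one rate per group). A minimal sketch of that matching rule, using a toy module rather than the real terratorch model:

```python
import torch
from torch import nn

# Illustrative stand-in for the task's model; real parameter names come from
# the backbone/decoder, e.g. "encoder.blocks.0.attn.qkv.weight".
model = nn.ModuleDict({"encoder": nn.Linear(8, 8), "decoder": nn.Linear(8, 2)})

lr, lr_overrides = 1e-4, {"encoder": 1e-5}

# A key matches every parameter whose qualified name contains it as a substring.
override_group = {
    "params": [p for n, p in model.named_parameters() if "encoder" in n],
    "lr": lr_overrides["encoder"],
}
rest_group = {
    "params": [
        p for n, p in model.named_parameters()
        if all(key not in n for key in lr_overrides)
    ]
}

optimizer = torch.optim.AdamW([override_group, rest_group], lr=lr)
print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0001]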
19 changes: 18 additions & 1 deletion terratorch/tasks/base_task.py
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Iterable

 import lightning
 from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
         optimizer = self.hparams["optimizer"]
         if optimizer is None:
             optimizer = "Adam"
+
+        parameters: Iterable
+        if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
+            parameters = []
+            for param_name, custom_lr in self.hparams["lr_overrides"].items():
+                p = [p for n, p in self.named_parameters() if param_name in n]
+                parameters.append({"params": p, "lr": custom_lr})
+            rest_p = [
+                p
+                for n, p in self.named_parameters()
+                if all(param_name not in n for param_name in self.hparams["lr_overrides"])
+            ]
+            parameters.append({"params": rest_p})
+        else:
+            parameters = self.parameters()
+
         return optimizer_factory(
             optimizer,
             self.hparams["lr"],
-            self.parameters(),
+            parameters,
             self.hparams["optimizer_hparams"],
             self.hparams["scheduler"],
             self.monitor,
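Taken together, the new branch builds one parameter group per override key plus a catch-all group that inherits the optimizer's default lr. A self-contained sketch of the same logic (the `build_param_groups` helper is hypothetical, written for this note, not terratorch API):

```python
from collections.abc import Iterable

import torch
from torch import nn


def build_param_groups(model: nn.Module, lr_overrides: dict[str, float] | None) -> Iterable:
    """Standalone mirror of the grouping logic above."""
    if not lr_overrides:
        return model.parameters()
    groups = []
    for key, custom_lr in lr_overrides.items():
        # Substring match: a key selects every parameter whose qualified name contains it.
        # Keys should select disjoint sets; torch optimizers reject a parameter in two groups.
        groups.append({"params": [p for n, p in model.named_parameters() if key in n], "lr": custom_lr})
    # Whatever matched no key falls through to the optimizer's default lr.
    groups.append({"params": [p for n, p in model.named_parameters() if all(k not in n for k in lr_overrides)]})
    return groups


# Toy model with terratorch-like attribute names (illustrative only).
net = nn.Sequential()
net.add_module("encoder", nn.Linear(4, 4))
net.add_module("decoder", nn.Linear(4, 4))
net.add_module("head", nn.Linear(4, 2))

opt = torch.optim.AdamW(build_param_groups(net, {"encoder": 1e-5, "head": 1e-3}), lr=1e-4)
print([g["lr"] for g in opt.param_groups])  # [1e-05, 0.001, 0.0001]
```

One subtlety of substring matching: a key like `encoder` would also capture a parameter named `aux_encoder.weight`, so keys should be chosen specifically enough that the groups stay disjoint.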
20 changes: 12 additions & 8 deletions terratorch/tasks/classification_tasks.py
@@ -16,7 +16,8 @@
 from terratorch.tasks.optimizer_factory import optimizer_factory
 from terratorch.tasks.base_task import TerraTorchTask

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 def to_class_prediction(y: ModelOutput) -> Tensor:
     y_hat = y.output
@@ -62,6 +63,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -97,6 +99,9 @@ def __init__(
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.aux_loss = aux_loss
         self.aux_heads = aux_heads
@@ -120,7 +125,6 @@ def __init__(
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"

-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.

@@ -131,8 +135,8 @@ def configure_losses(self) -> None:
         ignore_index = self.hparams["ignore_index"]

         class_weights = (
-            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
-        )
+            torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
+        )
         if loss == "ce":
             ignore_value = -100 if ignore_index is None else ignore_index
             self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
@@ -200,7 +204,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -221,7 +225,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -239,7 +243,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["label"]
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -260,7 +264,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "label", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)

         y_hat = self(x).output
15 changes: 10 additions & 5 deletions terratorch/tasks/regression_tasks.py
@@ -24,7 +24,8 @@

 BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

-logger = logging.getLogger('terratorch')
+logger = logging.getLogger("terratorch")
+

 class RootLossWrapper(nn.Module):
     def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -152,6 +153,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -186,6 +188,9 @@ def __init__(
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                 used to determine if inference is done on the whole image or through tiling.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -266,7 +271,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         model_output: ModelOutput = self(x, **rest)
         loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -287,7 +292,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -329,7 +334,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         x = batch["image"]
         y = batch["mask"]
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
         loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
         self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -350,7 +355,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
         x = batch["image"]
         file_names = batch["filename"] if "filename" in batch else None
         other_keys = batch.keys() - {"image", "mask", "filename"}
-        rest = {k:batch[k] for k in other_keys}
+        rest = {k: batch[k] for k in other_keys}

         def model_forward(x):
             return self(x).output
11 changes: 9 additions & 2 deletions terratorch/tasks/segmentation_tasks.py
@@ -64,6 +64,7 @@ def __init__(
         class_names: list[str] | None = None,
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
+        lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor

@@ -106,6 +107,9 @@ def __init__(
             test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                 multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                 which assumes only one test dataloader is used.
+            lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
+                parameters. The key should be a substring of the parameter names (it will check that the substring is
+                contained in the parameter name) and the value should be the new lr. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
@@ -294,7 +298,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
batch["prediction"] = y_hat_hard

if isinstance(batch["image"], dict):
if hasattr(datamodule, 'rgb_modality'):
if hasattr(datamodule, "rgb_modality"):
# Generic multimodal dataset
batch["image"] = batch["image"][datamodule.rgb_modality]
else:
@@ -343,7 +347,10 @@ def model_forward(x):
         if self.tiled_inference_parameters:
             y_hat: Tensor = tiled_inference(
                 # TODO: tiled inference does not work with additional input data (**rest)
-                model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
+                model_forward,
+                x,
+                self.hparams["model_args"]["num_classes"],
+                self.tiled_inference_parameters,
             )
         else:
             y_hat: Tensor = self(x, **rest).output
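For reference, the same override can also be passed when the task is built from Python rather than the Lightning CLI. A sketch restricted to `init_args` that appear in the example config above (values copied from that YAML; aux heads, class weights, and scheduler omitted for brevity):

```python
from terratorch.tasks import SemanticSegmentationTask

# Sketch: same lr_overrides usage via the Python API, mirroring the YAML above.
task = SemanticSegmentationTask(
    model_args={
        "backbone": "prithvi_vit_100",
        "backbone_pretrained": True,
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "decoder": "FCNDecoder",
        "decoder_channels": 256,
        "decoder_num_convs": 4,
        "head_dropout": 0.1,
        "head_channel_list": [256],
        "num_classes": 2,
        "necks": [{"name": "SelectIndices", "indices": [-1]}, {"name": "ReshapeTokensToImage"}],
    },
    model_factory="EncoderDecoderFactory",
    loss="ce",
    ignore_index=-1,
    optimizer="AdamW",
    lr=1e-4,  # default rate for parameters not captured by an override
    lr_overrides={"encoder": 1e-5},  # substring-matched against parameter names
)
```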