Feat: Implement option to have multiple learning rates #329

Merged (9 commits) on Jan 17, 2025
19 changes: 18 additions & 1 deletion terratorch/tasks/base_task.py
@@ -1,4 +1,5 @@
import logging
from collections.abc import Iterable

import lightning
from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
optimizer = self.hparams["optimizer"]
if optimizer is None:
optimizer = "Adam"

parameters: Iterable
if self.hparams.get("lr_overrides", None) is not None and len(self.hparams["lr_overrides"]) > 0:
parameters = []
for param_name, custom_lr in self.hparams["lr_overrides"].items():
p = [p for n, p in self.model.named_parameters() if param_name in n]
parameters.append({"params": p, "lr": custom_lr})
rest_p = [
p
for n, p in self.model.named_parameters()
if all(param_name not in n for param_name in self.hparams["lr_overrides"])
]
parameters.append({"params": rest_p})
else:
parameters = self.parameters()

return optimizer_factory(
optimizer,
self.hparams["lr"],
self.parameters(),
parameters,
self.hparams["optimizer_hparams"],
self.hparams["scheduler"],
self.monitor,
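To see what the new branch in configure_optimizers does, here is a minimal standalone sketch of the same substring-matching logic, using a toy torch model in place of a TerraTorch task (the model and the override key below are illustrative, not from this PR):

import torch
from torch import nn

# Toy model standing in for self.model; lr_overrides keys are substring-matched
# against the names yielded by named_parameters() ("0.weight", "1.bias", ...).
model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 2))
lr_overrides = {"0.": 1e-5}  # any parameter whose name contains "0." gets lr=1e-5

parameters = []
for param_name, custom_lr in lr_overrides.items():
    group = [p for n, p in model.named_parameters() if param_name in n]
    parameters.append({"params": group, "lr": custom_lr})
# Parameters matched by no override key fall into a group without an explicit
# lr, so the optimizer's default applies to them.
rest = [p for n, p in model.named_parameters() if all(k not in n for k in lr_overrides)]
parameters.append({"params": rest})

optimizer = torch.optim.Adam(parameters, lr=1e-3)
for group in optimizer.param_groups:
    print(group["lr"], len(group["params"]))  # 1e-05 for layer 0, 0.001 for the rest

Note that each parameter should match at most one override key: torch optimizers raise an error when a parameter appears in more than one parameter group.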
20 changes: 12 additions & 8 deletions terratorch/tasks/classification_tasks.py
@@ -16,7 +16,8 @@
from terratorch.tasks.optimizer_factory import optimizer_factory
from terratorch.tasks.base_task import TerraTorchTask

logger = logging.getLogger('terratorch')
logger = logging.getLogger("terratorch")


def to_class_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
@@ -62,6 +63,7 @@ def __init__(
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
class_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
"""Constructor

@@ -97,6 +99,9 @@
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Defaults to numeric ordering.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
parameters. Each key is treated as a substring and matched against parameter names (every parameter
whose name contains the key is affected), and the value is the new lr. Defaults to None.
"""
self.aux_loss = aux_loss
self.aux_heads = aux_heads
@@ -120,7 +125,6 @@ def __init__(
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
self.monitor = f"{self.val_metrics.prefix}loss"


def configure_losses(self) -> None:
"""Initialize the loss criterion.

@@ -131,8 +135,8 @@ def configure_losses(self) -> None:
ignore_index = self.hparams["ignore_index"]

class_weights = (
torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
)
torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None
)
if loss == "ce":
ignore_value = -100 if ignore_index is None else ignore_index
self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights)
@@ -200,7 +204,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
x = batch["image"]
y = batch["label"]
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -221,7 +225,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
x = batch["image"]
y = batch["label"]
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -239,7 +243,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
x = batch["image"]
y = batch["label"]
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -260,7 +264,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
x = batch["image"]
file_names = batch["filename"] if "filename" in batch else None
other_keys = batch.keys() - {"image", "label", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

y_hat = self(x).output
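With the new argument in place, a hypothetical fine-tuning setup could give the backbone a much smaller lr than the rest of the model. The import path matches this repository's layout, but the model_args/model_factory values below are illustrative only; only lr and lr_overrides are what this PR adds:

from terratorch.tasks import ClassificationTask

# Hypothetical configuration sketch, not taken from this PR's tests.
task = ClassificationTask(
    model_args={"backbone": "prithvi_vit_100", "num_classes": 10},
    model_factory="PrithviModelFactory",
    loss="ce",
    lr=1e-3,                          # default lr for all parameters...
    lr_overrides={"backbone": 1e-5},  # ...except those whose name contains "backbone"
)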
15 changes: 10 additions & 5 deletions terratorch/tasks/regression_tasks.py
@@ -24,7 +24,8 @@

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger('terratorch')
logger = logging.getLogger("terratorch")


class RootLossWrapper(nn.Module):
def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
@@ -152,6 +153,7 @@ def __init__(
freeze_decoder: bool = False, # noqa: FBT001, FBT002
plot_on_val: bool | int = 10,
tiled_inference_parameters: TiledInferenceParameters | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
"""Constructor

@@ -186,6 +188,9 @@
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
used to determine if inference is done on the whole image or through tiling.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
parameters. Each key is treated as a substring and matched against parameter names (every parameter
whose name contains the key is affected), and the value is the new lr. Defaults to None.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
@@ -266,7 +271,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) ->
x = batch["image"]
y = batch["mask"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}

model_output: ModelOutput = self(x, **rest)
loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
@@ -287,7 +292,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
x = batch["image"]
y = batch["mask"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
@@ -329,7 +334,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
x = batch["image"]
y = batch["mask"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)
loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
@@ -350,7 +355,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
x = batch["image"]
file_names = batch["filename"] if "filename" in batch else None
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}
rest = {k: batch[k] for k in other_keys}

def model_forward(x):
return self(x).output
11 changes: 9 additions & 2 deletions terratorch/tasks/segmentation_tasks.py
@@ -64,6 +64,7 @@ def __init__(
class_names: list[str] | None = None,
tiled_inference_parameters: TiledInferenceParameters = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
"""Constructor

@@ -106,6 +107,9 @@
test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
which assumes only one test dataloader is used.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr for specific
parameters. Each key is treated as a substring and matched against parameter names (every parameter
whose name contains the key is affected), and the value is the new lr. Defaults to None.
"""
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
@@ -294,7 +298,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
batch["prediction"] = y_hat_hard

if isinstance(batch["image"], dict):
if hasattr(datamodule, 'rgb_modality'):
if hasattr(datamodule, "rgb_modality"):
# Generic multimodal dataset
batch["image"] = batch["image"][datamodule.rgb_modality]
else:
@@ -343,7 +347,10 @@ def model_forward(x):
if self.tiled_inference_parameters:
y_hat: Tensor = tiled_inference(
# TODO: tiled inference does not work with additional input data (**rest)
model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
model_forward,
x,
self.hparams["model_args"]["num_classes"],
self.tiled_inference_parameters,
)
else:
y_hat: Tensor = self(x, **rest).output