
Feat: Implement option to have multiple learning rates #329

Merged
merged 9 commits on Jan 17, 2025
Implement reduce_lr option for multiple lr
Signed-off-by: Francesc Marti Escofet <f.martiescofet@gmail.com>
fmartiescofet committed Dec 18, 2024
commit 7ecbbb21b050320403bff67c197c2a1bdf7f9adb
19 changes: 18 additions & 1 deletion terratorch/tasks/base_task.py
@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Iterable
 
 import lightning
 from lightning.pytorch.callbacks import Callback
@@ -52,10 +53,26 @@ def configure_optimizers(
         optimizer = self.hparams["optimizer"]
         if optimizer is None:
             optimizer = "Adam"
+
+        parameters: Iterable
+        if self.hparams.get("reduce_lr", None) is not None and len(self.hparams["reduce_lr"]) > 0:
+            parameters = []
+            for param_name, reduce_factor in self.hparams["reduce_lr"]:
+                p = [p for n, p in self.model.named_parameters() if param_name in n]
+                parameters.append({"params": p, "lr": self.hparams["lr"] / reduce_factor})
+            rest_p = [
+                p
+                for n, p in self.model.named_parameters()
+                if all(param_name not in n for param_name, _ in self.hparams["reduce_lr"])
+            ]
+            parameters.append({"params": rest_p})
+        else:
+            parameters = self.parameters()
+
         return optimizer_factory(
             optimizer,
             self.hparams["lr"],
-            self.parameters(),
+            parameters,
             self.hparams["optimizer_hparams"],
             self.hparams["scheduler"],
             self.monitor,
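In effect, `configure_optimizers` now hands the optimizer a list of torch-style parameter groups instead of a flat iterable: each `(substring, factor)` entry in `reduce_lr` collects the matching parameters into a group with `lr / factor`, while everything left over goes into a group without an explicit `"lr"` key and therefore falls back to the optimizer's base learning rate. A minimal standalone sketch of the same grouping logic (the `Toy` model and its attribute names are invented for illustration, not from terratorch):

```python
import torch


class Toy(torch.nn.Module):
    """Hypothetical stand-in for a terratorch model."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)
        self.decoder = torch.nn.Linear(8, 2)


model = Toy()
lr = 1e-3
reduce_lr = [("backbone", 10.0)]  # backbone params train at lr / 10

# Same grouping as the patched configure_optimizers:
parameters = []
for param_name, reduce_factor in reduce_lr:
    p = [p for n, p in model.named_parameters() if param_name in n]
    parameters.append({"params": p, "lr": lr / reduce_factor})
rest_p = [
    p
    for n, p in model.named_parameters()
    if all(param_name not in n for param_name, _ in reduce_lr)
]
parameters.append({"params": rest_p})  # no "lr" key -> optimizer default lr

optimizer = torch.optim.Adam(parameters, lr=lr)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.001]
```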
4 changes: 3 additions & 1 deletion terratorch/tasks/classification_tasks.py
@@ -62,6 +62,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        reduce_lr: list[tuple[str, float]] | None = None,
     ) -> None:
         """Constructor
 
@@ -97,6 +98,8 @@ def __init__(
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            reduce_lr (list[tuple[str, float]] | None, optional): List of tuples, each with a substring of the
+                parameter names whose learning rate should be reduced and the factor to reduce it by. Defaults to None.
         """
         self.aux_loss = aux_loss
         self.aux_heads = aux_heads
@@ -120,7 +123,6 @@ def __init__(
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"
 
-
     def configure_losses(self) -> None:
         """Initialize the loss criterion.
3 changes: 3 additions & 0 deletions terratorch/tasks/regression_tasks.py
@@ -152,6 +152,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        reduce_lr: list[tuple[str, float]] | None = None,
     ) -> None:
         """Constructor
 
@@ -186,6 +187,8 @@ def __init__(
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                 used to determine if inference is done on the whole image or through tiling.
+            reduce_lr (list[tuple[str, float]] | None, optional): List of tuples, each with a substring of the
+                parameter names whose learning rate should be reduced and the factor to reduce it by. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
3 changes: 3 additions & 0 deletions terratorch/tasks/segmentation_tasks.py
@@ -64,6 +64,7 @@ def __init__(
         class_names: list[str] | None = None,
         tiled_inference_parameters: TiledInferenceParameters = None,
         test_dataloaders_names: list[str] | None = None,
+        reduce_lr: list[tuple[str, float]] | None = None,
     ) -> None:
         """Constructor
 
@@ -106,6 +107,8 @@ def __init__(
             test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
                 multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
                 which assumes only one test dataloader is used.
+            reduce_lr (list[tuple[str, float]] | None, optional): List of tuples, each with a substring of the
+                parameter names whose learning rate should be reduced and the factor to reduce it by. Defaults to None.
         """
         self.tiled_inference_parameters = tiled_inference_parameters
         self.aux_loss = aux_loss
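One caveat when picking substrings (an observation about torch.optim behavior, not something enforced in this diff): a parameter whose name matches more than one `reduce_lr` entry ends up in several groups, and torch optimizers reject parameters that appear in more than one parameter group. A minimal invented repro:

```python
import torch

# The same parameters deliberately placed in two groups.
model = torch.nn.Linear(4, 4)
p = list(model.parameters())
groups = [{"params": p, "lr": 1e-4}, {"params": p}]
torch.optim.Adam(groups, lr=1e-3)
# ValueError: some parameters appear in more than one parameter group
```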