Skip to content

Commit

Permalink
Merge pull request #378 from IBM/freeze/head/decoder
Browse files Browse the repository at this point in the history
Freeze/head/decoder
  • Loading branch information
Joao-L-S-Almeida authored Jan 30, 2025
2 parents 519ee8a + b5c0c0e commit 89b913d
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ lightning==2.4.0
git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf
timm==1.0.11
numpy==1.26.4
jsonargparse==4.32.0
jsonargparse<=4.35.0

# These dependencies are optional
# and must be installed just in case
Expand Down
2 changes: 2 additions & 0 deletions terratorch/models/pixel_wise_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def freeze_encoder(self):

def freeze_decoder(self):
freeze_module(self.decoder)

def freeze_head(self):
freeze_module(self.head)

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions terratorch/models/scalar_output_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def freeze_encoder(self):

def freeze_decoder(self):
freeze_module(self.decoder)

def freeze_head(self):
freeze_module(self.head)

def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
Expand Down
3 changes: 3 additions & 0 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def configure_models(self) -> None:
if self.hparams["freeze_decoder"]:
self.model.freeze_decoder()

if self.hparams["freeze_head"]:
self.model.freeze_head()

def configure_optimizers(
self,
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
freeze_head: bool = False, # noqa: FBT002, FBT001
class_names: list[str] | None = None,
test_dataloaders_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
Expand Down Expand Up @@ -99,7 +100,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Defaults to numeric ordering.
test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT001, FBT002
freeze_head: bool = False, # noqa: FBT001, FBT002
plot_on_val: bool | int = 10,
tiled_inference_parameters: TiledInferenceParameters | None = None,
test_dataloaders_names: list[str] | None = None,
Expand Down Expand Up @@ -186,7 +187,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
freeze_head: bool = False,
plot_on_val: bool | int = 10,
class_names: list[str] | None = None,
tiled_inference_parameters: TiledInferenceParameters = None,
Expand Down Expand Up @@ -97,7 +98,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ model:
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
freeze_head: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
Expand Down

0 comments on commit 89b913d

Please sign in to comment.