diff --git a/requirements/required.txt b/requirements/required.txt index f298ec97..986de635 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -15,7 +15,7 @@ lightning==2.4.0 git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf timm==1.0.11 numpy==1.26.4 -jsonargparse==4.32.0 +jsonargparse<=4.35.0 # These dependencies are optional # and must be installed just in case diff --git a/terratorch/models/pixel_wise_model.py b/terratorch/models/pixel_wise_model.py index 2ee84ce7..04597d53 100644 --- a/terratorch/models/pixel_wise_model.py +++ b/terratorch/models/pixel_wise_model.py @@ -81,6 +81,8 @@ def freeze_encoder(self): def freeze_decoder(self): freeze_module(self.decoder) + + def freeze_head(self): freeze_module(self.head) @staticmethod diff --git a/terratorch/models/scalar_output_model.py b/terratorch/models/scalar_output_model.py index 4b73eea6..c0b5db99 100644 --- a/terratorch/models/scalar_output_model.py +++ b/terratorch/models/scalar_output_model.py @@ -75,6 +75,8 @@ def freeze_encoder(self): def freeze_decoder(self): freeze_module(self.decoder) + + def freeze_head(self): freeze_module(self.head) def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py index f5ac30b3..c04d7f22 100644 --- a/terratorch/tasks/base_task.py +++ b/terratorch/tasks/base_task.py @@ -47,6 +47,9 @@ def configure_models(self) -> None: if self.hparams["freeze_decoder"]: self.model.freeze_decoder() + if self.hparams["freeze_head"]: + self.model.freeze_head() + def configure_optimizers( self, ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig": diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index 249634bf..60b33235 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -64,6 +64,7 @@ def __init__( # freeze_backbone: bool = False, # noqa: FBT001, FBT002 freeze_decoder: bool = False, # noqa: FBT002, FBT001 + freeze_head: bool = False, # noqa: FBT002, FBT001 class_names: list[str] | None = None, test_dataloaders_names: list[str] | None = None, lr_overrides: dict[str, float] | None = None, @@ -99,7 +100,8 @@ def __init__( scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler. Overriden by config / cli specification through LightningCLI. freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False. - freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False. + freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False. + freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False. class_names (list[str] | None, optional): List of class names passed to metrics for better naming. Defaults to numeric ordering. test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index c03eeb32..9849b0b2 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -153,6 +153,7 @@ def __init__( # freeze_backbone: bool = False, # noqa: FBT001, FBT002 freeze_decoder: bool = False, # noqa: FBT001, FBT002 + freeze_head: bool = False, # noqa: FBT001, FBT002 plot_on_val: bool | int = 10, tiled_inference_parameters: TiledInferenceParameters | None = None, test_dataloaders_names: list[str] | None = None, @@ -186,7 +187,8 @@ def __init__( scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler. Overriden by config / cli specification through LightningCLI. freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False. - freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False. + freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False. + freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False. plot_on_val (bool | int, optional): Whether to plot visualizations on validation. If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index aff17ab0..f8ab82af 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -60,6 +60,7 @@ def __init__( # freeze_backbone: bool = False, # noqa: FBT001, FBT002 freeze_decoder: bool = False, # noqa: FBT002, FBT001 + freeze_head: bool = False, plot_on_val: bool | int = 10, class_names: list[str] | None = None, tiled_inference_parameters: TiledInferenceParameters = None, @@ -97,7 +98,8 @@ def __init__( scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler. Overriden by config / cli specification through LightningCLI. freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False. - freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False. + freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False. + freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False. plot_on_val (bool | int, optional): Whether to plot visualizations on validation. If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. class_names (list[str] | None, optional): List of class names passed to metrics for better naming. diff --git a/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml b/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml index cea8a0ea..11f21975 100644 --- a/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml +++ b/tests/resources/configs/manufactured-finetune_prithvi_swin_B.yaml @@ -129,6 +129,7 @@ model: ignore_index: -1 freeze_backbone: true freeze_decoder: false + freeze_head: false model_factory: PrithviModelFactory # uncomment this block for tiled inference