From 124cd0230e3aaf891d6bf17481107fe0f69c65c5 Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 18 Dec 2024 12:04:47 +0100 Subject: [PATCH 1/6] Implement multiple test dataloaders all tasks Signed-off-by: Francesc Marti Escofet --- terratorch/tasks/base_task.py | 5 ++- terratorch/tasks/classification_tasks.py | 46 ++++++++++++++++------- terratorch/tasks/regression_tasks.py | 48 ++++++++++++++++++------ terratorch/tasks/segmentation_tasks.py | 12 +++--- 4 files changed, 78 insertions(+), 33 deletions(-) diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py index e59aaf39..69f744b7 100644 --- a/terratorch/tasks/base_task.py +++ b/terratorch/tasks/base_task.py @@ -71,8 +71,9 @@ def on_validation_epoch_end(self) -> None: self.val_metrics.reset() def on_test_epoch_end(self) -> None: - self.log_dict(self.test_metrics.compute(), sync_dist=True) - self.test_metrics.reset() + for metrics in self.test_metrics: + self.log_dict(metrics.compute(), sync_dist=True) + metrics.reset() def _do_plot_samples(self, batch_index): if not self.plot_on_val: # dont plot if self.plot_on_val is 0 diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index 89974004..bf952ab8 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any import logging import lightning @@ -16,7 +17,8 @@ from terratorch.tasks.optimizer_factory import optimizer_factory from terratorch.tasks.base_task import TerraTorchTask -logger = logging.getLogger('terratorch') +logger = logging.getLogger("terratorch") + def to_class_prediction(y: ModelOutput) -> Tensor: y_hat = y.output @@ -33,6 +35,7 @@ class ClassificationTask(TerraTorchTask): - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - Allows the setting of optimizers in the constructor - It provides mIoU with both Micro and Macro averaging + - Allows to evaluate on multiple test dataloaders .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect @@ -62,6 +65,7 @@ def __init__( freeze_backbone: bool = False, # noqa: FBT001, FBT002 freeze_decoder: bool = False, # noqa: FBT002, FBT001 class_names: list[str] | None = None, + test_dataloaders_names: list[str] | None = None, ) -> None: """Constructor @@ -97,6 +101,9 @@ def __init__( freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False. class_names (list[str] | None, optional): List of class names passed to metrics for better naming. Defaults to numeric ordering. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. """ self.aux_loss = aux_loss self.aux_heads = aux_heads @@ -116,11 +123,12 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" - def configure_losses(self) -> None: """Initialize the loss criterion. @@ -131,8 +139,8 @@ def configure_losses(self) -> None: ignore_index = self.hparams["ignore_index"] class_weights = ( - torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None - ) + torch.Tensor(self.hparams["class_weights"]) if self.hparams["class_weights"] is not None else None + ) if loss == "ce": ignore_value = -100 if ignore_index is None else ignore_index self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=class_weights) @@ -187,7 +195,12 @@ def configure_metrics(self) -> None: ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") - self.test_metrics = metrics.clone(prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]] + ) + else: + self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")]) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -200,7 +213,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) @@ -221,7 +234,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) @@ -239,12 +252,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None x = batch["image"] y = batch["label"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat_hard = to_class_prediction(model_output) - self.test_metrics.update(y_hat_hard, y) + self.test_metrics[dataloader_idx].update(y_hat_hard, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. @@ -260,7 +280,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] other_keys = batch.keys() - {"image", "label", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) y_hat = self(x).output diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 29bbc00f..f856e1b8 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -1,6 +1,7 @@ """This module contains the regression task and its auxiliary classes.""" from collections.abc import Sequence +from functools import partial from typing import Any import logging @@ -24,7 +25,8 @@ BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 -logger = logging.getLogger('terratorch') +logger = logging.getLogger("terratorch") + class RootLossWrapper(nn.Module): def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None: @@ -129,7 +131,8 @@ class PixelwiseRegressionTask(TerraTorchTask): - Accepts the specification of a model factory - Logs metrics per class - Does not have any callbacks by default (TorchGeo tasks do early stopping by default) - - Allows the setting of optimizers in the constructor""" + - Allows the setting of optimizers in the constructor + - Allows to evaluate on multiple test dataloaders""" def __init__( self, @@ -152,6 +155,7 @@ def __init__( freeze_decoder: bool = False, # noqa: FBT001, FBT002 plot_on_val: bool | int = 10, tiled_inference_parameters: TiledInferenceParameters | None = None, + test_dataloaders_names: list[str] | None = None, ) -> None: """Constructor @@ -186,6 +190,9 @@ def __init__( If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs. tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters used to determine if inference is done on the whole image or through tiling. + test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when + multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None, + which assumes only one test dataloader is used. """ self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss @@ -206,7 +213,9 @@ def __init__( self.model = model self.train_loss_handler = LossHandler(self.train_metrics.prefix) - self.test_loss_handler = LossHandler(self.test_metrics.prefix) + self.test_loss_handler: list[LossHandler] = [] + for metrics in self.test_metrics: + self.test_loss_handler.append(LossHandler(metrics.prefix)) self.val_loss_handler = LossHandler(self.val_metrics.prefix) self.monitor = f"{self.val_metrics.prefix}loss" self.plot_on_val = int(plot_on_val) @@ -253,7 +262,17 @@ def wrap_metrics_with_ignore_index(metrics): self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/") self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/") - self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/") + if self.hparams["test_dataloaders_names"] is not None: + self.test_metrics = nn.ModuleList( + [ + MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/") + for dl_name in self.hparams["test_dataloaders_names"] + ] + ) + else: + self.test_metrics = nn.ModuleList( + [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")] + ) def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the train loss and additional metrics. @@ -266,7 +285,7 @@ def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.train_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) @@ -287,7 +306,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) @@ -329,12 +348,19 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None x = batch["image"] y = batch["mask"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) - loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) - self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0]) + if dataloader_idx >= len(self.test_loss_handler): + msg = "You are returning more than one test dataloader but not defining enough test_dataloaders_names." + raise ValueError(msg) + loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss) + self.test_loss_handler[dataloader_idx].log_loss( + partial(self.log, add_dataloader_idx=False), # We don't need the dataloader idx as prefixes are different + loss_dict=loss, + batch_size=x.shape[0], + ) y_hat = model_output.output - self.test_metrics.update(y_hat, y) + self.test_metrics[dataloader_idx].update(y_hat, y) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor: """Compute the predicted class probabilities. @@ -350,7 +376,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T x = batch["image"] file_names = batch["filename"] other_keys = batch.keys() - {"image", "mask", "filename"} - rest = {k:batch[k] for k in other_keys} + rest = {k: batch[k] for k in other_keys} def model_forward(x): return self(x).output diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 48e80221..0d999dd1 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -261,11 +261,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None y_hat_hard = to_segmentation_prediction(model_output) self.test_metrics[dataloader_idx].update(y_hat_hard, y) - def on_test_epoch_end(self) -> None: - for metrics in self.test_metrics: - self.log_dict(metrics.compute(), sync_dist=True) - metrics.reset() - def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the validation loss and additional metrics. Args: @@ -291,7 +286,7 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - batch["prediction"] = y_hat_hard if isinstance(batch["image"], dict): - if hasattr(datamodule, 'rgb_modality'): + if hasattr(datamodule, "rgb_modality"): # Generic multimodal dataset batch["image"] = batch["image"][datamodule.rgb_modality] else: @@ -337,7 +332,10 @@ def model_forward(x): if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( # TODO: tiled inference does not work with additional input data (**rest) - model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters + model_forward, + x, + self.hparams["model_args"]["num_classes"], + self.tiled_inference_parameters, ) else: y_hat: Tensor = self(x, **rest).output From 7f1105f87d3d19ad7c72fcd5488e4cc1cb45fbbd Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Mon, 20 Jan 2025 15:41:02 +0100 Subject: [PATCH 2/6] Trigger tests Signed-off-by: Francesc Marti Escofet From d2fbe8b729f69870d90688cdb59cd2378c59787f Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 09:47:25 +0100 Subject: [PATCH 3/6] Pin jsonargparse Signed-off-by: Francesc Marti Escofet --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6a477ac4..010efcf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", "segmentation-models-pytorch>=0.3" + "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 0c82fd1f058e4a197409ad0266faebd5d790753f Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 10:38:34 +0100 Subject: [PATCH 4/6] Fix Signed-off-by: Francesc Marti Escofet --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 010efcf6..8343f6b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "mlflow>=2.12.1", # broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977 "lightning[pytorch-extra]>=2,!=2.3.*", - "segmentation-models-pytorch>=0.3" + "segmentation-models-pytorch>=0.3", "jsonargparse<=4.35.0", # Dependencies not available on PyPI ] From 025c2e47b70ad2bfa1fd002b7fd94ce5a74add0c Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Wed, 22 Jan 2025 12:37:01 +0100 Subject: [PATCH 5/6] Pin jsonargparse in required.txt Signed-off-by: Francesc Marti Escofet --- requirements/required.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/required.txt b/requirements/required.txt index c4aec6ea..986de635 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -15,6 +15,7 @@ lightning==2.4.0 git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf timm==1.0.11 numpy==1.26.4 +jsonargparse<=4.35.0 # These dependencies are optional # and must be installed just in case From e641e7cf63cc6fa059d76163187f931771d6cd1a Mon Sep 17 00:00:00 2001 From: Francesc Marti Escofet Date: Thu, 23 Jan 2025 10:49:31 +0100 Subject: [PATCH 6/6] Remove pin Signed-off-by: Francesc Marti Escofet --- requirements/required.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/required.txt b/requirements/required.txt index 986de635..c4aec6ea 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -15,7 +15,6 @@ lightning==2.4.0 git+https://github.com/qubvel-org/segmentation_models.pytorch.git@3952e1f8e9684a385a81e30381b8fb5b1ac086cf timm==1.0.11 numpy==1.26.4 -jsonargparse<=4.35.0 # These dependencies are optional # and must be installed just in case