diff --git a/terratorch/tasks/base_task.py b/terratorch/tasks/base_task.py
index 53a26cec..f5ac30b3 100644
--- a/terratorch/tasks/base_task.py
+++ b/terratorch/tasks/base_task.py
@@ -88,8 +88,9 @@ def on_validation_epoch_end(self) -> None:
         self.val_metrics.reset()
 
     def on_test_epoch_end(self) -> None:
-        self.log_dict(self.test_metrics.compute(), sync_dist=True)
-        self.test_metrics.reset()
+        for metrics in self.test_metrics:
+            self.log_dict(metrics.compute(), sync_dist=True)
+            metrics.reset()
 
     def _do_plot_samples(self, batch_index):
         if not self.plot_on_val:  # dont plot if self.plot_on_val is 0
diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py
index 8630a0ff..249634bf 100644
--- a/terratorch/tasks/classification_tasks.py
+++ b/terratorch/tasks/classification_tasks.py
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any
 import logging
 import lightning
@@ -34,6 +35,7 @@ class ClassificationTask(TerraTorchTask):
     - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
     - Allows the setting of optimizers in the constructor
     - It provides mIoU with both Micro and Macro averaging
+    - Allows evaluation on multiple test dataloaders
 
     .. note::
        * 'Micro' averaging suits overall performance evaluation but may not reflect
@@ -63,6 +65,7 @@ def __init__(
         freeze_backbone: bool = False,  # noqa: FBT001, FBT002
         freeze_decoder: bool = False,  # noqa: FBT002, FBT001
         class_names: list[str] | None = None,
+        test_dataloaders_names: list[str] | None = None,
         lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor
@@ -99,6 +102,9 @@ def __init__(
             freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
             class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
                 Defaults to numeric ordering.
+            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
+                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
+                which assumes only one test dataloader is used.
             lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                 parameters. The key should be a substring of the parameter names (it will check the substring is
                 contained in the parameter name)and the value should be the new lr. Defaults to None.
@@ -121,7 +127,9 @@ def __init__(
         self.model = model
 
         self.train_loss_handler = LossHandler(self.train_metrics.prefix)
-        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
+        self.test_loss_handler: list[LossHandler] = []
+        for metrics in self.test_metrics:
+            self.test_loss_handler.append(LossHandler(metrics.prefix))
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"
 
@@ -191,7 +199,12 @@ def configure_metrics(self) -> None:
         )
         self.train_metrics = metrics.clone(prefix="train/")
         self.val_metrics = metrics.clone(prefix="val/")
-        self.test_metrics = metrics.clone(prefix="test/")
+        if self.hparams["test_dataloaders_names"] is not None:
+            self.test_metrics = nn.ModuleList(
+                [metrics.clone(prefix=f"test/{dl_name}/") for dl_name in self.hparams["test_dataloaders_names"]]
+            )
+        else:
+            self.test_metrics = nn.ModuleList([metrics.clone(prefix="test/")])
 
     def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the train loss and additional metrics.
@@ -245,10 +258,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         other_keys = batch.keys() - {"image", "label", "filename"}
         rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
-        loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
-        self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
+        if dataloader_idx >= len(self.test_loss_handler):
+            msg = "Found more test dataloaders than entries in test_dataloaders_names. Provide one name per test dataloader."
+            raise ValueError(msg)
+        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
+        self.test_loss_handler[dataloader_idx].log_loss(
+            partial(self.log, add_dataloader_idx=False),  # Prefixes already distinguish dataloaders, so skip the dataloader_idx suffix
+            loss_dict=loss,
+            batch_size=x.shape[0],
+        )
         y_hat_hard = to_class_prediction(model_output)
-        self.test_metrics.update(y_hat_hard, y)
+        self.test_metrics[dataloader_idx].update(y_hat_hard, y)
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py
index c57541fa..c03eeb32 100644
--- a/terratorch/tasks/regression_tasks.py
+++ b/terratorch/tasks/regression_tasks.py
@@ -1,6 +1,7 @@
 """This module contains the regression task and its auxiliary classes."""
 
 from collections.abc import Sequence
+from functools import partial
 from typing import Any
 import logging
 
@@ -130,7 +131,8 @@ class PixelwiseRegressionTask(TerraTorchTask):
     - Accepts the specification of a model factory
     - Logs metrics per class
     - Does not have any callbacks by default (TorchGeo tasks do early stopping by default)
-    - Allows the setting of optimizers in the constructor"""
+    - Allows the setting of optimizers in the constructor
+    - Allows evaluation on multiple test dataloaders"""
 
     def __init__(
         self,
@@ -153,6 +155,7 @@ def __init__(
         freeze_decoder: bool = False,  # noqa: FBT001, FBT002
         plot_on_val: bool | int = 10,
         tiled_inference_parameters: TiledInferenceParameters | None = None,
+        test_dataloaders_names: list[str] | None = None,
         lr_overrides: dict[str, float] | None = None,
     ) -> None:
         """Constructor
@@ -188,6 +191,9 @@ def __init__(
                 If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
             tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
                 used to determine if inference is done on the whole image or through tiling.
+            test_dataloaders_names (list[str] | None, optional): Names used to differentiate metrics when
+                multiple dataloaders are returned by test_dataloader in the datamodule. Defaults to None,
+                which assumes only one test dataloader is used.
             lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
                 parameters. The key should be a substring of the parameter names (it will check the substring is
                 contained in the parameter name)and the value should be the new lr. Defaults to None.
@@ -211,7 +217,9 @@ def __init__(
         self.model = model
 
         self.train_loss_handler = LossHandler(self.train_metrics.prefix)
-        self.test_loss_handler = LossHandler(self.test_metrics.prefix)
+        self.test_loss_handler: list[LossHandler] = []
+        for metrics in self.test_metrics:
+            self.test_loss_handler.append(LossHandler(metrics.prefix))
         self.val_loss_handler = LossHandler(self.val_metrics.prefix)
         self.monitor = f"{self.val_metrics.prefix}loss"
         self.plot_on_val = int(plot_on_val)
@@ -258,7 +266,17 @@ def wrap_metrics_with_ignore_index(metrics):
 
         self.train_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="train/")
         self.val_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="val/")
-        self.test_metrics = MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")
+        if self.hparams["test_dataloaders_names"] is not None:
+            self.test_metrics = nn.ModuleList(
+                [
+                    MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix=f"test/{dl_name}/")
+                    for dl_name in self.hparams["test_dataloaders_names"]
+                ]
+            )
+        else:
+            self.test_metrics = nn.ModuleList(
+                [MetricCollection(wrap_metrics_with_ignore_index(instantiate_metrics()), prefix="test/")]
+            )
 
     def training_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the train loss and additional metrics.
@@ -336,10 +354,17 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         other_keys = batch.keys() - {"image", "mask", "filename"}
         rest = {k: batch[k] for k in other_keys}
         model_output: ModelOutput = self(x, **rest)
-        loss = self.test_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
-        self.test_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=x.shape[0])
+        if dataloader_idx >= len(self.test_loss_handler):
+            msg = "Found more test dataloaders than entries in test_dataloaders_names. Provide one name per test dataloader."
+            raise ValueError(msg)
+        loss = self.test_loss_handler[dataloader_idx].compute_loss(model_output, y, self.criterion, self.aux_loss)
+        self.test_loss_handler[dataloader_idx].log_loss(
+            partial(self.log, add_dataloader_idx=False),  # Prefixes already distinguish dataloaders, so skip the dataloader_idx suffix
+            loss_dict=loss,
+            batch_size=x.shape[0],
+        )
         y_hat = model_output.output
-        self.test_metrics.update(y_hat, y)
+        self.test_metrics[dataloader_idx].update(y_hat, y)
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Tensor:
         """Compute the predicted class probabilities.
diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py
index 328dbac8..aff17ab0 100644
--- a/terratorch/tasks/segmentation_tasks.py
+++ b/terratorch/tasks/segmentation_tasks.py
@@ -268,11 +268,6 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
         y_hat_hard = to_segmentation_prediction(model_output)
         self.test_metrics[dataloader_idx].update(y_hat_hard, y)
 
-    def on_test_epoch_end(self) -> None:
-        for metrics in self.test_metrics:
-            self.log_dict(metrics.compute(), sync_dist=True)
-            metrics.reset()
-
     def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the validation loss and additional metrics.
         Args:
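
Usage note: the sketch below shows how the new option is meant to be wired up. The datamodule class, dataset attributes, and model arguments are hypothetical placeholders; only test_dataloaders_names and the test/<name>/ metric prefixes come from this change. A datamodule whose test_dataloader returns several loaders is paired with one name per loader, so each loader's test loss and metrics are logged under a distinct prefix.

# Minimal usage sketch. All names below except test_dataloaders_names are
# hypothetical placeholders, not part of this change.
import lightning.pytorch as pl
from torch.utils.data import DataLoader

from terratorch.tasks import ClassificationTask


class MultiTestDataModule(pl.LightningDataModule):
    """Datamodule returning one test dataloader per held-out split."""

    def __init__(self, region_a_dataset, region_b_dataset):
        super().__init__()
        self.region_a_dataset = region_a_dataset
        self.region_b_dataset = region_b_dataset

    def test_dataloader(self):
        # Order must match test_dataloaders_names passed to the task.
        return [
            DataLoader(self.region_a_dataset, batch_size=8),
            DataLoader(self.region_b_dataset, batch_size=8),
        ]


task = ClassificationTask(
    model_args=...,     # placeholder: model configuration dict
    model_factory=...,  # placeholder: model factory name
    # One name per test dataloader; metrics are logged under
    # test/region_a/... and test/region_b/...
    test_dataloaders_names=["region_a", "region_b"],
)

With this setup, trainer.test(task, datamodule=...) logs test/region_a/loss, test/region_b/loss, and the corresponding metric collections separately. If the datamodule returned a third loader without a matching third name, the ValueError added in test_step would be raised.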