diff --git a/.gitignore b/.gitignore index a88104df..1b3897ff 100644 --- a/.gitignore +++ b/.gitignore @@ -12,9 +12,11 @@ venv/* examples/notebooks/config.yaml examples/notebooks/wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc tests/all_ecos_random/* +examples/**/*tif* **/climatology/* **/lightning_logs/* **/merra-2/* **/*.bin *.stdout *.log +**/*.un~ diff --git a/terratorch/cli_tools.py b/terratorch/cli_tools.py index ebd1c971..a00f2ad3 100644 --- a/terratorch/cli_tools.py +++ b/terratorch/cli_tools.py @@ -101,7 +101,8 @@ def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"): logger.info(f"Saving output to {out_file_name} ...") write_tiff(result, os.path.join(out_dir, out_file_name), metadata) -def import_custom_modules(custom_modules_path:None | Path | str =None) -> None: + +def import_custom_modules(custom_modules_path: str | Path | None = None) -> None: if custom_modules_path: @@ -394,11 +395,9 @@ def instantiate_classes(self) -> None: elif hasattr(self.config, "predict") and hasattr(self.config.predict, "custom_modules_path"): custom_modules_path = self.config.predict.custom_modules_path else: - logger.info("No custom module is being used.") - custom_modules_path = None + custom_modules_path = os.getenv("TERRATORCH_CUSTOM_MODULE_PATH", None) - if custom_modules_path: - import_custom_modules(custom_modules_path) + import_custom_modules(custom_modules_path) def build_lightning_cli( args: ArgsType = None, diff --git a/terratorch/models/backbones/prithvi_mae.py b/terratorch/models/backbones/prithvi_mae.py index e1a9e52a..c209b25d 100644 --- a/terratorch/models/backbones/prithvi_mae.py +++ b/terratorch/models/backbones/prithvi_mae.py @@ -28,6 +28,7 @@ from timm.layers import to_2tuple from timm.models.vision_transformer import Block +logger = logging.getLogger(__name__) def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ @@ -152,14 +153,16 @@ def __init__( self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.log_warning = True def forward(self, x): B, C, T, H, W = x.shape - if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: - logging.getLogger(__name__).warning( - f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." - f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + if (self.log_warning and + (T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1)): + logger.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." + f"The border will be ignored, add backbone_padding for pixel-wise tasks.") + self.log_warning = False x = self.proj(x) if self.flatten: diff --git a/terratorch/tasks/__init__.py b/terratorch/tasks/__init__.py index 82eea1be..348cc501 100644 --- a/terratorch/tasks/__init__.py +++ b/terratorch/tasks/__init__.py @@ -6,7 +6,8 @@ wxc_present = True from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask except ImportError as e: - print('wxc_downscaling not installed') + import logging + logging.getLogger('terratorch').debug('wxc_downscaling not installed') wxc_present = False diff --git a/terratorch/tasks/classification_tasks.py b/terratorch/tasks/classification_tasks.py index be585eeb..c7ab25bc 100644 --- a/terratorch/tasks/classification_tasks.py +++ b/terratorch/tasks/classification_tasks.py @@ -1,5 +1,5 @@ from typing import Any - +import logging import lightning import matplotlib.pyplot as plt import torch @@ -16,6 +16,7 @@ from terratorch.tasks.loss_handler import LossHandler from terratorch.tasks.optimizer_factory import optimizer_factory +logger = logging.getLogger('terratorch') def to_class_prediction(y: ModelOutput) -> Tensor: y_hat = y.output @@ -43,7 +44,8 @@ class ClassificationTask(BaseTask): def __init__( self, model_args: dict, - model_factory: str, + model_factory: str | None = None, + model: torch.nn.Module | None = None, loss: str = "ce", aux_heads: list[AuxiliaryHead] | None = None, aux_loss: dict[str, float] | None = None, @@ -67,7 +69,9 @@ def __init__( Defaults to None. model_args (Dict): Arguments passed to the model factory. - model_factory (str): ModelFactory class to be used to instantiate the model. + model_factory (str, optional): ModelFactory class to be used to instantiate the model. + Is ignored when model is provided. + model (torch.nn.Module, optional): Custom model. loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce". aux_loss (dict[str, float] | None, optional): Auxiliary loss weights. @@ -96,8 +100,21 @@ def __init__( """ self.aux_loss = aux_loss self.aux_heads = aux_heads - self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + + if model is not None and model_factory is not None: + logger.warning("A model_factory and a model was provided. The model_factory is ignored.") + if model is None and model_factory is None: + raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.") + + if model_factory and model is None: + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + super().__init__() + + if model: + # Custom model + self.model = model + self.train_loss_handler = LossHandler(self.train_metrics.prefix) self.test_loss_handler = LossHandler(self.test_metrics.prefix) self.val_loss_handler = LossHandler(self.val_metrics.prefix) @@ -108,9 +125,16 @@ def configure_callbacks(self) -> list[Callback]: return [] def configure_models(self) -> None: + if not hasattr(self, "model_factory"): + if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]: + logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.") + # Skipping model factory because custom model is provided + return + self.model: Model = self.model_factory.build_model( "classification", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) + if self.hparams["freeze_backbone"]: if self.hparams.get("peft_config", None) is not None: msg = "PEFT should be run with freeze_backbone = False" diff --git a/terratorch/tasks/regression_tasks.py b/terratorch/tasks/regression_tasks.py index 3e9d2cde..a7211a37 100644 --- a/terratorch/tasks/regression_tasks.py +++ b/terratorch/tasks/regression_tasks.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from typing import Any +import logging import lightning import matplotlib.pyplot as plt import torch @@ -23,6 +24,7 @@ BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 +logger = logging.getLogger('terratorch') class RootLossWrapper(nn.Module): def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None: @@ -132,7 +134,8 @@ class PixelwiseRegressionTask(BaseTask): def __init__( self, model_args: dict, - model_factory: str, + model_factory: str | None = None, + model: torch.nn.Module | None = None, loss: str = "mse", aux_heads: list[AuxiliaryHead] | None = None, aux_loss: dict[str, float] | None = None, @@ -154,7 +157,9 @@ def __init__( Args: model_args (Dict): Arguments passed to the model factory. - model_factory (str): Name of ModelFactory class to be used to instantiate the model. + model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model. + Is ignored when model is provided. + model (torch.nn.Module, optional): Custom model. loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss. Defaults to "mse". aux_loss (dict[str, float] | None, optional): Auxiliary loss weights. @@ -185,8 +190,21 @@ def __init__( self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss self.aux_heads = aux_heads - self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + + if model is not None and model_factory is not None: + logger.warning("A model_factory and a model was provided. The model_factory is ignored.") + if model is None and model_factory is None: + raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.") + + if model_factory and model is None: + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + super().__init__() + + if model: + # Custom_model + self.model = model + self.train_loss_handler = LossHandler(self.train_metrics.prefix) self.test_loss_handler = LossHandler(self.test_metrics.prefix) self.val_loss_handler = LossHandler(self.val_metrics.prefix) @@ -198,14 +216,22 @@ def configure_callbacks(self) -> list[Callback]: return [] def configure_models(self) -> None: + if not hasattr(self, "model_factory"): + if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]: + logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.") + # Skipping model factory because custom model is provided + return + self.model: Model = self.model_factory.build_model( "regression", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) + if self.hparams["freeze_backbone"]: if self.hparams.get("peft_config", None) is not None: msg = "PEFT should be run with freeze_backbone = False" raise ValueError(msg) self.model.freeze_encoder() + if self.hparams["freeze_decoder"]: self.model.freeze_decoder() @@ -393,13 +419,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T file_names = batch["filename"] other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k:batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) def model_forward(x): return self(x).output if self.tiled_inference_parameters: + # TODO: tiled inference does not work with additional input data (**rest) y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters) else: - y_hat: Tensor = self(x).output + y_hat: Tensor = self(x, **rest).output return y_hat, file_names diff --git a/terratorch/tasks/segmentation_tasks.py b/terratorch/tasks/segmentation_tasks.py index 523c3250..290bdccd 100644 --- a/terratorch/tasks/segmentation_tasks.py +++ b/terratorch/tasks/segmentation_tasks.py @@ -1,8 +1,8 @@ +from typing import Any from functools import partial import os -from typing import Any - +import logging import lightning import matplotlib.pyplot as plt import segmentation_models_pytorch as smp @@ -22,6 +22,7 @@ BATCH_IDX_FOR_VALIDATION_PLOTTING = 10 +logger = logging.getLogger('terratorch') def to_segmentation_prediction(y: ModelOutput) -> Tensor: y_hat = y.output @@ -43,7 +44,8 @@ class SemanticSegmentationTask(BaseTask): def __init__( self, model_args: dict, - model_factory: str, + model_factory: str | None = None, + model: torch.nn.Module | None = None, loss: str = "ce", aux_heads: list[AuxiliaryHead] | None = None, aux_loss: dict[str, float] | None = None, @@ -69,7 +71,9 @@ def __init__( Defaults to None. model_args (Dict): Arguments passed to the model factory. - model_factory (str): ModelFactory class to be used to instantiate the model. + model_factory (str, optional): ModelFactory class to be used to instantiate the model. + Is ignored when model is provided. + model (torch.nn.Module, optional): Custom model. loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard' or 'focal' loss. Defaults to "ce". aux_loss (dict[str, float] | None, optional): Auxiliary loss weights. @@ -106,8 +110,21 @@ def __init__( self.tiled_inference_parameters = tiled_inference_parameters self.aux_loss = aux_loss self.aux_heads = aux_heads - self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + + if model is not None and model_factory is not None: + logger.warning("A model_factory and a model was provided. The model_factory is ignored.") + if model is None and model_factory is None: + raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.") + + if model_factory and model is None: + self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory) + super().__init__() + + if model is not None: + # Custom model + self.model = model + self.train_loss_handler = LossHandler(self.train_metrics.prefix) self.test_loss_handler: list[LossHandler] = [] for metrics in self.test_metrics: @@ -121,9 +138,16 @@ def configure_callbacks(self) -> list[Callback]: return [] def configure_models(self) -> None: + if not hasattr(self, "model_factory"): + if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]: + logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.") + # Skipping model factory because custom model is provided + return + self.model: Model = self.model_factory.build_model( "segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"] ) + if self.hparams["freeze_backbone"]: if self.hparams.get("peft_config", None) is not None: msg = "PEFT should be run with freeze_backbone = False" @@ -279,9 +303,11 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) - """ x = batch["image"] y = batch["mask"] + other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} model_output: ModelOutput = self(x, **rest) + loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss) self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0]) y_hat_hard = to_segmentation_prediction(model_output) @@ -368,16 +394,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T file_names = batch["filename"] other_keys = batch.keys() - {"image", "mask", "filename"} rest = {k: batch[k] for k in other_keys} - model_output: ModelOutput = self(x, **rest) def model_forward(x): return self(x).output if self.tiled_inference_parameters: y_hat: Tensor = tiled_inference( + # TODO: tiled inference does not work with additional input data (**rest) model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters ) else: - y_hat: Tensor = self(x).output + y_hat: Tensor = self(x, **rest).output y_hat = y_hat.argmax(dim=1) return y_hat, file_names