Skip to content

Commit

Permalink
Merge pull request #299 from IBM/improve/custom
Browse files Browse the repository at this point in the history
Improve/custom
  • Loading branch information
Joao-L-S-Almeida authored Dec 6, 2024
2 parents 5e0bb9e + 07245e8 commit fc1bc25
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ venv/*
examples/notebooks/config.yaml
examples/notebooks/wxc_input_u_v_t_p_output_theta_uw_vw_era5_training_data_hourly_2015_constant_mu_sigma_scaling05.nc
tests/all_ecos_random/*
examples/**/*tif*
**/climatology/*
**/lightning_logs/*
**/merra-2/*
**/*.bin
*.stdout
*.log
**/*.un~
9 changes: 4 additions & 5 deletions terratorch/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
logger.info(f"Saving output to {out_file_name} ...")
write_tiff(result, os.path.join(out_dir, out_file_name), metadata)

def import_custom_modules(custom_modules_path:None | Path | str =None) -> None:

def import_custom_modules(custom_modules_path: str | Path | None = None) -> None:

if custom_modules_path:

Expand Down Expand Up @@ -394,11 +395,9 @@ def instantiate_classes(self) -> None:
elif hasattr(self.config, "predict") and hasattr(self.config.predict, "custom_modules_path"):
custom_modules_path = self.config.predict.custom_modules_path
else:
logger.info("No custom module is being used.")
custom_modules_path = None
custom_modules_path = os.getenv("TERRATORCH_CUSTOM_MODULE_PATH", None)

if custom_modules_path:
import_custom_modules(custom_modules_path)
import_custom_modules(custom_modules_path)

def build_lightning_cli(
args: ArgsType = None,
Expand Down
11 changes: 7 additions & 4 deletions terratorch/models/backbones/prithvi_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block

logger = logging.getLogger(__name__)

def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Expand Down Expand Up @@ -152,14 +153,16 @@ def __init__(

self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.log_warning = True

def forward(self, x):
B, C, T, H, W = x.shape

if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
logging.getLogger(__name__).warning(
f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
if (self.log_warning and
(T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1)):
logger.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
self.log_warning = False

x = self.proj(x)
if self.flatten:
Expand Down
3 changes: 2 additions & 1 deletion terratorch/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
wxc_present = True
from terratorch.tasks.wxc_downscaling_task import WxCDownscalingTask
except ImportError as e:
print('wxc_downscaling not installed')
import logging
logging.getLogger('terratorch').debug('wxc_downscaling not installed')
wxc_present = False


Expand Down
32 changes: 28 additions & 4 deletions terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any

import logging
import lightning
import matplotlib.pyplot as plt
import torch
Expand All @@ -16,6 +16,7 @@
from terratorch.tasks.loss_handler import LossHandler
from terratorch.tasks.optimizer_factory import optimizer_factory

logger = logging.getLogger('terratorch')

def to_class_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
Expand Down Expand Up @@ -43,7 +44,8 @@ class ClassificationTask(BaseTask):
def __init__(
self,
model_args: dict,
model_factory: str,
model_factory: str | None = None,
model: torch.nn.Module | None = None,
loss: str = "ce",
aux_heads: list[AuxiliaryHead] | None = None,
aux_loss: dict[str, float] | None = None,
Expand All @@ -67,7 +69,9 @@ def __init__(
Defaults to None.
model_args (Dict): Arguments passed to the model factory.
model_factory (str): ModelFactory class to be used to instantiate the model.
model_factory (str, optional): ModelFactory class to be used to instantiate the model.
Is ignored when model is provided.
model (torch.nn.Module, optional): Custom model.
loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard' or 'focal' loss.
Defaults to "ce".
aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
Expand Down Expand Up @@ -96,8 +100,21 @@ def __init__(
"""
self.aux_loss = aux_loss
self.aux_heads = aux_heads
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

if model is not None and model_factory is not None:
logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
if model is None and model_factory is None:
raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

if model_factory and model is None:
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

super().__init__()

if model:
# Custom model
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler = LossHandler(self.test_metrics.prefix)
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
Expand All @@ -108,9 +125,16 @@ def configure_callbacks(self) -> list[Callback]:
return []

def configure_models(self) -> None:
if not hasattr(self, "model_factory"):
if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
# Skipping model factory because custom model is provided
return

self.model: Model = self.model_factory.build_model(
"classification", aux_decoders=self.aux_heads, **self.hparams["model_args"]
)

if self.hparams["freeze_backbone"]:
if self.hparams.get("peft_config", None) is not None:
msg = "PEFT should be run with freeze_backbone = False"
Expand Down
36 changes: 31 additions & 5 deletions terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any

import logging
import lightning
import matplotlib.pyplot as plt
import torch
Expand All @@ -23,6 +24,7 @@

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger('terratorch')

class RootLossWrapper(nn.Module):
def __init__(self, loss_function: nn.Module, reduction: None | str = "mean") -> None:
Expand Down Expand Up @@ -132,7 +134,8 @@ class PixelwiseRegressionTask(BaseTask):
def __init__(
self,
model_args: dict,
model_factory: str,
model_factory: str | None = None,
model: torch.nn.Module | None = None,
loss: str = "mse",
aux_heads: list[AuxiliaryHead] | None = None,
aux_loss: dict[str, float] | None = None,
Expand All @@ -154,7 +157,9 @@ def __init__(
Args:
model_args (Dict): Arguments passed to the model factory.
model_factory (str): Name of ModelFactory class to be used to instantiate the model.
model_factory (str, optional): Name of ModelFactory class to be used to instantiate the model.
Is ignored when model is provided.
model (torch.nn.Module, optional): Custom model.
loss (str, optional): Loss to be used. Currently, supports 'mse', 'rmse', 'mae' or 'huber' loss.
Defaults to "mse".
aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
Expand Down Expand Up @@ -185,8 +190,21 @@ def __init__(
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
self.aux_heads = aux_heads
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

if model is not None and model_factory is not None:
logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
if model is None and model_factory is None:
raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

if model_factory and model is None:
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

super().__init__()

if model:
# Custom_model
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler = LossHandler(self.test_metrics.prefix)
self.val_loss_handler = LossHandler(self.val_metrics.prefix)
Expand All @@ -198,14 +216,22 @@ def configure_callbacks(self) -> list[Callback]:
return []

def configure_models(self) -> None:
if not hasattr(self, "model_factory"):
if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
# Skipping model factory because custom model is provided
return

self.model: Model = self.model_factory.build_model(
"regression", aux_decoders=self.aux_heads, **self.hparams["model_args"]
)

if self.hparams["freeze_backbone"]:
if self.hparams.get("peft_config", None) is not None:
msg = "PEFT should be run with freeze_backbone = False"
raise ValueError(msg)
self.model.freeze_encoder()

if self.hparams["freeze_decoder"]:
self.model.freeze_decoder()

Expand Down Expand Up @@ -393,13 +419,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
file_names = batch["filename"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k:batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

def model_forward(x):
return self(x).output

if self.tiled_inference_parameters:
# TODO: tiled inference does not work with additional input data (**rest)
y_hat: Tensor = tiled_inference(model_forward, x, 1, self.tiled_inference_parameters)
else:
y_hat: Tensor = self(x).output
y_hat: Tensor = self(x, **rest).output
return y_hat, file_names
40 changes: 33 additions & 7 deletions terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

from typing import Any
from functools import partial
import os
from typing import Any

import logging
import lightning
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
Expand All @@ -22,6 +22,7 @@

BATCH_IDX_FOR_VALIDATION_PLOTTING = 10

logger = logging.getLogger('terratorch')

def to_segmentation_prediction(y: ModelOutput) -> Tensor:
y_hat = y.output
Expand All @@ -43,7 +44,8 @@ class SemanticSegmentationTask(BaseTask):
def __init__(
self,
model_args: dict,
model_factory: str,
model_factory: str | None = None,
model: torch.nn.Module | None = None,
loss: str = "ce",
aux_heads: list[AuxiliaryHead] | None = None,
aux_loss: dict[str, float] | None = None,
Expand All @@ -69,7 +71,9 @@ def __init__(
Defaults to None.
model_args (Dict): Arguments passed to the model factory.
model_factory (str): ModelFactory class to be used to instantiate the model.
model_factory (str, optional): ModelFactory class to be used to instantiate the model.
Is ignored when model is provided.
model (torch.nn.Module, optional): Custom model.
loss (str, optional): Loss to be used. Currently, supports 'ce', 'jaccard' or 'focal' loss.
Defaults to "ce".
aux_loss (dict[str, float] | None, optional): Auxiliary loss weights.
Expand Down Expand Up @@ -106,8 +110,21 @@ def __init__(
self.tiled_inference_parameters = tiled_inference_parameters
self.aux_loss = aux_loss
self.aux_heads = aux_heads
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

if model is not None and model_factory is not None:
logger.warning("A model_factory and a model was provided. The model_factory is ignored.")
if model is None and model_factory is None:
raise ValueError("A model_factory or a model (torch.nn.Module) must be provided.")

if model_factory and model is None:
self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)

super().__init__()

if model is not None:
# Custom model
self.model = model

self.train_loss_handler = LossHandler(self.train_metrics.prefix)
self.test_loss_handler: list[LossHandler] = []
for metrics in self.test_metrics:
Expand All @@ -121,9 +138,16 @@ def configure_callbacks(self) -> list[Callback]:
return []

def configure_models(self) -> None:
if not hasattr(self, "model_factory"):
if self.hparams["freeze_backbone"] or self.hparams["freeze_decoder"]:
logger.warning("freeze_backbone and freeze_decoder are ignored if a custom model is provided.")
# Skipping model factory because custom model is provided
return

self.model: Model = self.model_factory.build_model(
"segmentation", aux_decoders=self.aux_heads, **self.hparams["model_args"]
)

if self.hparams["freeze_backbone"]:
if self.hparams.get("peft_config", None) is not None:
msg = "PEFT should be run with freeze_backbone = False"
Expand Down Expand Up @@ -279,9 +303,11 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -
"""
x = batch["image"]
y = batch["mask"]

other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

loss = self.val_loss_handler.compute_loss(model_output, y, self.criterion, self.aux_loss)
self.val_loss_handler.log_loss(self.log, loss_dict=loss, batch_size=y.shape[0])
y_hat_hard = to_segmentation_prediction(model_output)
Expand Down Expand Up @@ -368,16 +394,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> T
file_names = batch["filename"]
other_keys = batch.keys() - {"image", "mask", "filename"}
rest = {k: batch[k] for k in other_keys}
model_output: ModelOutput = self(x, **rest)

def model_forward(x):
return self(x).output

if self.tiled_inference_parameters:
y_hat: Tensor = tiled_inference(
# TODO: tiled inference does not work with additional input data (**rest)
model_forward, x, self.hparams["model_args"]["num_classes"], self.tiled_inference_parameters
)
else:
y_hat: Tensor = self(x).output
y_hat: Tensor = self(x, **rest).output
y_hat = y_hat.argmax(dim=1)
return y_hat, file_names

0 comments on commit fc1bc25

Please sign in to comment.