Merge pull request #342 from IBM/add/pad
Adding padding at the input when necessary
romeokienzler authored Jan 29, 2025
2 parents 3cfee48 + 0dc1e95 commit 519ee8a
Showing 28 changed files with 665 additions and 64 deletions.
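In short: inputs whose height or width is not a multiple of the backbone's patch size are padded up to the next multiple before the forward pass, and model outputs are cropped back to the original resolution afterwards. A minimal, self-contained sketch of that round trip (the shapes, the patch size of 14, and the reflect mode are illustrative assumptions, not values from this PR):

import torch
import torch.nn.functional as F

def pad_to_patch_multiple(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # pad H and W on the bottom/right up to the next multiple of patch_size
    h, w = x.shape[-2:]
    h_pad = (patch_size - h % patch_size) % patch_size
    w_pad = (patch_size - w % patch_size) % patch_size
    return F.pad(x, (0, w_pad, 0, h_pad), mode="reflect")

x = torch.randn(1, 3, 100, 100)                # 100 is not a multiple of 14
padded = pad_to_patch_multiple(x, patch_size=14)
assert padded.shape[-2:] == (112, 112)
out = padded[..., :x.shape[-2], :x.shape[-1]]  # crop back, as pixel_wise_model.py does below
assert out.shape[-2:] == (100, 100)

Padding only on the bottom and right keeps the image origin fixed, so a plain [:h, :w] slice recovers exactly the original pixels.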
52 changes: 52 additions & 0 deletions examples/scripts/create_images.py
@@ -0,0 +1,52 @@
from PIL import Image
import os
import random
import numpy as np
import tifffile as tiff
from argparse import ArgumentParser
from osgeo import gdal
from osgeo import osr

parser = ArgumentParser()
parser.add_argument("--input_file")
parser.add_argument("--output_dir")
parser.add_argument("--n_copies", type=int, default=2)

args = parser.parse_args()
input_file = args.input_file
output_dir = args.output_dir
n_copies = args.n_copies

pad_limit = 4

# config
GDAL_DATA_TYPE = gdal.GDT_Int32
GEOTIFF_DRIVER_NAME = "GTiff"
NO_DATA = 15
SPATIAL_REFERENCE_SYSTEM_WKID = 4326

for c in range(n_copies):

    pad = 3  # fixed for reproducibility; use random.randint(1, pad_limit) for random padding
    filename = os.path.split(input_file)[-1]
    output_file = os.path.join(output_dir, filename.replace(".tif", f"_{c}.tif"))

    # embed the image in a zero border of `pad` pixels on every side
    imarray = tiff.imread(input_file)
    im_shape = imarray.shape
    im_shape_ext = tuple(i + 2 * pad for i in im_shape[:-1]) + (im_shape[-1],)
    output = np.zeros(im_shape_ext)
    output[pad:-pad, pad:-pad, :] = imarray

    # create the GeoTIFF driver and the padded output raster
    driver = gdal.GetDriverByName(GEOTIFF_DRIVER_NAME)

    output_raster = driver.Create(output_file,
                                  output.shape[1],
                                  output.shape[0],
                                  output.shape[-1],
                                  eType=GDAL_DATA_TYPE)

    # NOTE: assumed completion — the otherwise-unused NO_DATA and
    # SPATIAL_REFERENCE_SYSTEM_WKID constants above suggest the raster was
    # meant to be georeferenced and the padded array written out, roughly:
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(SPATIAL_REFERENCE_SYSTEM_WKID)
    output_raster.SetProjection(srs.ExportToWkt())
    for band in range(output.shape[-1]):
        raster_band = output_raster.GetRasterBand(band + 1)
        raster_band.SetNoDataValue(NO_DATA)
        raster_band.WriteArray(output[:, :, band])
    output_raster.FlushCache()
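
Presumably the script is invoked along these lines (paths and values are placeholders):

python examples/scripts/create_images.py --input_file scene.tif --output_dir ./padded --n_copies 2

Each copy embeds the source image in a zero border three pixels wide and writes the result as a GeoTIFF.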

14 changes: 0 additions & 14 deletions terratorch/models/backbones/prithvi_vit.py
@@ -157,20 +157,6 @@ def checkpoint_filter_fn_mae(

    return state_dict


def pad_images(imgs: Tensor, patch_size: int, padding: str) -> Tensor:
    p = patch_size
    t, h, w = imgs.shape[-3:]
    # pad needed to reach the next multiple of the patch size (0 if already aligned)
    h_pad, w_pad = (p - h % p) % p, (p - w % p) % p
    if h_pad > 0 or w_pad > 0:
        imgs = torch.stack([
            nn.functional.pad(img, (0, w_pad, 0, h_pad), mode=padding)
            for img in imgs  # apply per image to avoid NotImplementedError from torch.nn.functional.pad
        ])
    return imgs
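
This helper is removed here and, per the new import in pixel_wise_model.py below, now lives in terratorch.models.utils. A quick sketch of its contract (the shapes are made-up examples):

from terratorch.models.utils import pad_images
import torch

x = torch.randn(2, 6, 1, 224, 230)      # (B, C, T, H, W); W not divisible by 16
padded = pad_images(x, patch_size=16, padding="reflect")
assert padded.shape[-2:] == (224, 240)  # H already aligned; W padded 230 -> 240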


def _create_prithvi(
    variant: str,
    pretrained: bool = False,  # noqa: FBT001, FBT002
4 changes: 2 additions & 2 deletions terratorch/models/backbones/select_patch_embed_weights.py
@@ -1,6 +1,5 @@
# Copyright contributors to the Terratorch project


import logging
import warnings

@@ -14,6 +13,7 @@ def patch_embed_weights_are_compatible(model_patch_embed: torch.Tensor, checkpoint_patch_embed: torch.Tensor):
    # check all dimensions are the same except for channel dimension
    if len(model_patch_embed.shape) != len(checkpoint_patch_embed.shape):
        return False

    model_shape = [model_patch_embed.shape[i] for i in range(len(model_patch_embed.shape)) if i != 1]
    checkpoint_shape = [checkpoint_patch_embed.shape[i] for i in range(len(checkpoint_patch_embed.shape)) if i != 1]
    return model_shape == checkpoint_shape
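
For illustration, the check compares every dimension of the patch-embedding weights except the channel dimension (dim 1), so a checkpoint trained with a different band count still passes (the shapes below are made up):

import torch
model_w = torch.zeros(768, 6, 16, 16)  # (embed_dim, in_channels, kH, kW)
ckpt_w = torch.zeros(768, 3, 16, 16)   # identical except in_channels
assert patch_embed_weights_are_compatible(model_w, ckpt_w)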
@@ -68,4 +68,4 @@ def select_patch_embed_weights(

    state_dict[proj_key] = temp_weight

    return state_dict
24 changes: 21 additions & 3 deletions terratorch/models/clay_model_factory.py
@@ -1,6 +1,7 @@
import importlib
import sys
from collections.abc import Callable
import logging

import timm
import torch
@@ -108,6 +109,7 @@ def build_model(

        # Path for accessing the model source code.
        self.syspath_kwarg = "model_sys_path"
        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

        # TODO: support auxiliary heads
        if not isinstance(backbone, nn.Module):
@@ -120,8 +122,6 @@
            msg = f"Task {task} not supported. Please choose one of {SUPPORTED_TASKS}"
            raise NotImplementedError(msg)

        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")

        # Trying to find the model on HuggingFace.
        try:
            backbone: nn.Module = timm.create_model(
@@ -143,6 +143,16 @@
            backbone: nn.Module = Embedder(ckpt_path=checkpoint_path, **backbone_kwargs)
            print("Model Clay was successfully restored.")

        # If the patch size is not provided in the config, inputs whose spatial size
        # is not a multiple of the patch size can cause errors, so try to infer it.
        patch_size = backbone_kwargs.get("patch_size", None)
        if patch_size is None:
            # Infer the patch size from the model by checking all backbone modules
            for module in backbone.modules():
                if hasattr(module, "patch_size"):
                    patch_size = module.patch_size
                    break
        padding = backbone_kwargs.get("padding", "reflect")
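
For context, a sketch of how these values could be pinned explicitly rather than inferred, via the factory's backbone_ prefix convention (the backbone and decoder names are placeholders, and `factory` is assumed to be a ClayModelFactory instance):

model = factory.build_model(
    task="segmentation",
    backbone="clay_v1",           # placeholder backbone name
    decoder="IdentityDecoder",    # placeholder decoder name
    num_classes=2,
    backbone_patch_size=16,       # ends up in backbone_kwargs["patch_size"]
    backbone_padding="reflect",   # ends up in backbone_kwargs["padding"]
)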

        # allow decoder to be a module passed directly
        decoder_cls = _get_decoder(decoder)
        decoder_kwargs, kwargs = extract_prefix_keys(kwargs, "decoder_")
@@ -157,7 +167,7 @@
        head_kwargs["num_classes"] = num_classes
        if aux_decoders is None:
            return _build_appropriate_model(
                task, backbone, decoder, head_kwargs, prepare_features_for_image_model, rescale=rescale
                task, backbone, decoder, head_kwargs, prepare_features_for_image_model, patch_size=patch_size, padding=padding, rescale=rescale
            )

        to_be_aux_decoders: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] = []
@@ -186,6 +196,8 @@
            decoder,
            head_kwargs,
            prepare_features_for_image_model,
            patch_size=patch_size,
            padding=padding,
            rescale=rescale,
            auxiliary_heads=to_be_aux_decoders,
        )
@@ -197,6 +209,8 @@ def _build_appropriate_model(
    decoder: nn.Module,
    head_kwargs: dict,
    prepare_features_for_image_model: Callable,
    patch_size: int | list | None,
    padding: str,
    rescale: bool = True,  # noqa: FBT001, FBT002
    auxiliary_heads: dict | None = None,
):
@@ -206,6 +220,8 @@
        backbone,
        decoder,
        head_kwargs,
        patch_size=patch_size,
        padding=padding,
        rescale=rescale,
        auxiliary_heads=auxiliary_heads,
    )
@@ -215,6 +231,8 @@
        backbone,
        decoder,
        head_kwargs,
        patch_size=patch_size,
        padding=padding,
        auxiliary_heads=auxiliary_heads,
    )

27 changes: 25 additions & 2 deletions terratorch/models/encoder_decoder_factory.py
@@ -1,8 +1,8 @@
# Copyright contributors to the Terratorch project


from typing import List
import warnings

import logging
from torch import nn

from terratorch.models.model import (
@@ -65,6 +65,8 @@ def _check_all_args_used(kwargs):
        msg = f"arguments {kwargs} were passed but not used."
        raise ValueError(msg)

def _get_argument_from_instance(model, name):
    return getattr(model._timm_module.patch_embed, name)[-1]

@MODEL_FACTORY_REGISTRY.register
class EncoderDecoderFactory(ModelFactory):
@@ -128,6 +130,17 @@ def build_model(
        backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
        backbone = _get_backbone(backbone, **backbone_kwargs)

        # If the patch size is not provided in the config, inputs whose spatial size
        # is not a multiple of the patch size can cause errors, so try to infer it.
        patch_size = backbone_kwargs.get("patch_size", None)

        if patch_size is None:
            # Infer the patch size from the model by checking all backbone modules
            for module in backbone.modules():
                if hasattr(module, "patch_size"):
                    patch_size = module.patch_size
                    break
        padding = backbone_kwargs.get("padding", "reflect")

        if peft_config is not None:
            if not backbone_kwargs.get("pretrained", False):
                msg = (
@@ -166,6 +179,8 @@
            backbone,
            decoder,
            head_kwargs,
            patch_size=patch_size,
            padding=padding,
            necks=neck_list,
            decoder_includes_head=decoder_includes_head,
            rescale=rescale,
@@ -191,6 +206,8 @@
            backbone,
            decoder,
            head_kwargs,
            patch_size=patch_size,
            padding=padding,
            necks=neck_list,
            decoder_includes_head=decoder_includes_head,
            rescale=rescale,
@@ -203,6 +220,8 @@ def _build_appropriate_model(
    backbone: nn.Module,
    decoder: nn.Module,
    head_kwargs: dict,
    patch_size: int | list | None,
    padding: str,
    decoder_includes_head: bool = False,
    necks: list[Neck] | None = None,
    rescale: bool = True,  # noqa: FBT001, FBT002
@@ -218,6 +237,8 @@
        backbone,
        decoder,
        head_kwargs,
        patch_size=patch_size,
        padding=padding,
        decoder_includes_head=decoder_includes_head,
        neck=neck_module,
        rescale=rescale,
@@ -229,6 +250,8 @@
        backbone,
        decoder,
        head_kwargs,
        patch_size=patch_size,
        padding=padding,
        decoder_includes_head=decoder_includes_head,
        neck=neck_module,
        auxiliary_heads=auxiliary_heads,
49 changes: 32 additions & 17 deletions terratorch/models/pixel_wise_model.py
@@ -1,13 +1,15 @@
# Copyright contributors to the Terratorch project

from typing import List
import logging
import torch
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from segmentation_models_pytorch.base import SegmentationModel
from torch import nn

from terratorch.models.heads import RegressionHead, SegmentationHead
from terratorch.models.model import AuxiliaryHeadWithDecoderWithoutInstantiatedHead, Model, ModelOutput

from terratorch.models.utils import pad_images

def freeze_module(module: nn.Module):
    for param in module.parameters():
@@ -26,6 +28,8 @@ def __init__(
        encoder: nn.Module,
        decoder: nn.Module,
        head_kwargs: dict,
        patch_size: int | None = None,
        padding: str | None = None,
        decoder_includes_head: bool = False,
        auxiliary_heads: list[AuxiliaryHeadWithDecoderWithoutInstantiatedHead] | None = None,
        neck: nn.Module | None = None,
@@ -69,6 +73,8 @@ def __init__(

        self.neck = neck
        self.rescale = rescale
        self.patch_size = patch_size
        self.padding = padding

    def freeze_encoder(self):
        freeze_module(self.encoder)
@@ -77,10 +83,6 @@ def freeze_decoder(self):
        freeze_module(self.decoder)
        freeze_module(self.head)

    # TODO: do this properly
    def check_input_shape(self, x: torch.Tensor) -> bool:  # noqa: ARG002
        return True

    @staticmethod
    def _check_for_single_channel_and_squeeze(x):
        if x.shape[1] == 1:
@@ -89,19 +91,27 @@ def _check_for_single_channel_and_squeeze(x):

    def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
        """Sequentially pass `x` through the model's encoder, decoder and heads"""
        self.check_input_shape(x)
        if isinstance(x, torch.Tensor):
            input_size = x.shape[-2:]
        elif 'image_size' in kwargs:
            input_size = kwargs['image_size']
        elif isinstance(x, dict):
            # Multimodal input is passed as a dict
            input_size = list(x.values())[0].shape[-2:]
        else:
            raise ValueError('Could not infer input shape.')

        def _get_size(x):
            if isinstance(x, torch.Tensor):
                return x.shape[-2:]
            elif isinstance(x, dict):
                # Multimodal input is passed as a dict (assume the first modality is an image)
                return list(x.values())[0].shape[-2:]
            elif 'image_size' in kwargs:
                return kwargs['image_size']
            else:
                raise ValueError('Could not infer image shape.')

        image_size = _get_size(x)
        if isinstance(x, torch.Tensor) and self.patch_size:
            # Only works for single-image modalities
            x = pad_images(x, self.patch_size, self.padding)
        input_size = _get_size(x)

        features = self.encoder(x, **kwargs)

        ## only for backwards compatibility with pre-neck times.
        # only for backwards compatibility with pre-neck times.
        if self.neck:
            prepare = self.neck
        else:
@@ -114,13 +124,18 @@ def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
        if self.rescale and mask.shape[-2:] != input_size:
            mask = F.interpolate(mask, size=input_size, mode="bilinear")
        mask = self._check_for_single_channel_and_squeeze(mask)
        mask = mask[..., :image_size[0], :image_size[1]]

        aux_outputs = {}
        for name, decoder in self.aux_heads.items():
            aux_output = decoder([f.clone() for f in features])
            if self.rescale and aux_output.shape[-2:] != input_size:
                aux_output = F.interpolate(aux_output, size=input_size, mode="bilinear")
            aux_output = self._check_for_single_channel_and_squeeze(aux_output)
            aux_output = aux_output[..., :image_size[0], :image_size[1]]
            aux_outputs[name] = aux_output

        return ModelOutput(output=mask, auxiliary_heads=aux_outputs)
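
Note the ordering above: image_size is measured before padding and input_size after, so the bilinear rescale restores the padded resolution and the final slice crops the pad away, which works because pad_images pads only on the bottom and right.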

    def _get_head(self, task: str, input_embed_dim: int, head_kwargs):
(The remaining changed files are collapsed and not shown.)
