Commit

Merge pull request #357 from fmartiescofet/unet_decoder
Feat: Implement Terratorch UNet decoder
Joao-L-S-Almeida authored Jan 9, 2025
2 parents fb94f7e + a8e99a6 commit 8b14211
Showing 5 changed files with 125 additions and 21 deletions.
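
For orientation, a minimal sketch of how the new decoder is wired up through the model factory, mirroring the parametrized tests added below. Parameter names and values are taken from those tests; VIT_UPERNET_NECK is a neck list defined in the test module and is not reproduced in this diff, so the snippet is illustrative rather than a complete, runnable configuration:

# Illustrative sketch based on the factory tests in this PR; not a complete config.
import torch
from terratorch.models import EncoderDecoderFactory
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS

factory = EncoderDecoderFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v1_100",
    decoder="UNetDecoder",                # registered in TERRATORCH_DECODER_REGISTRY (see unet_decoder.py below)
    decoder_channels=[256, 128, 64, 32],  # forwarded to UNetDecoder's `channels` argument
    backbone_bands=PRETRAINED_BANDS,
    backbone_pretrained=False,
    num_classes=2,
    necks=VIT_UPERNET_NECK,               # neck list the tests use for ViT-style backbones; defined in the test module
)
model.eval()
out = model(torch.ones((1, 6, 224, 224)))  # 6 input bands, as in the tests
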
14 changes: 13 additions & 1 deletion terratorch/models/decoders/__init__.py
@@ -6,5 +6,17 @@
from terratorch.models.decoders.upernet_decoder import UperNetDecoder
from terratorch.models.decoders.aspp_head import ASPPSegmentationHead, ASPPRegressionHead
from terratorch.models.decoders.mlp_decoder import MLPDecoder
from terratorch.models.decoders.unet_decoder import UNetDecoder

__all__ = ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "SatMAEHead", "SatMAEHeadViT", "SMPDecoderWrapper", "ASPPSegmentationHead", "ASPPRegressionHead", "MLPDecoder"]
__all__ = [
"FCNDecoder",
"UperNetDecoder",
"IdentityDecoder",
"SatMAEHead",
"SatMAEHeadViT",
"SMPDecoderWrapper",
"ASPPSegmentationHead",
"ASPPRegressionHead",
"MLPDecoder",
"UNetDecoder",
]
39 changes: 39 additions & 0 deletions terratorch/models/decoders/unet_decoder.py
@@ -0,0 +1,39 @@
import torch
from segmentation_models_pytorch.base.initialization import initialize_decoder
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from torch import nn

from terratorch.registry import TERRATORCH_DECODER_REGISTRY


@TERRATORCH_DECODER_REGISTRY.register
class UNetDecoder(nn.Module):
"""UNetDecoder. Wrapper around UNetDecoder from segmentation_models_pytorch to avoid ignoring the first layer."""

def __init__(
self, embed_dim: list[int], channels: list[int], use_batchnorm: bool = True, attention_type: str | None = None
):
"""Constructor
Args:
embed_dim (list[int]): Input embedding dimension for each input.
channels (list[int]): Channels used in the decoder.
use_batchnorm (bool, optional): Whether to use batchnorm. Defaults to True.
attention_type (str | None, optional): Attention type to use. Defaults to None.
"""
super().__init__()
self.decoder = UnetDecoder(
encoder_channels=[embed_dim[0], *embed_dim],
decoder_channels=channels,
n_blocks=len(channels),
use_batchnorm=use_batchnorm,
center=False,
attention_type=attention_type,
)
initialize_decoder(self.decoder)
self.out_channels = channels[-1]

def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
# The first layer is ignored in the original UnetDecoder, so we need to duplicate the first layer
x = [x[0].clone(), *x]
return self.decoder(*x)
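
A quick usage note on the wrapper: segmentation_models_pytorch's UnetDecoder ignores the first feature it is given, so the forward above prepends a copy of the first feature map to keep the real one in the skip-connection path; each decoder block then upsamples by 2x, so the output ends up at twice the spatial size of the first feature. A minimal standalone sketch, with shapes matching the unit test added below:

import torch
from terratorch.models.decoders.unet_decoder import UNetDecoder

# Four features at decreasing resolution, e.g. from a backbone plus pyramid-building necks.
features = [
    torch.randn(2, 64, 224, 224),
    torch.randn(2, 128, 112, 112),
    torch.randn(2, 256, 56, 56),
    torch.randn(2, 512, 28, 28),
]
decoder = UNetDecoder(embed_dim=[64, 128, 256, 512], channels=[256, 128, 64, 32])
out = decoder(features)
print(out.shape)  # torch.Size([2, 32, 448, 448]) -- twice the spatial size of the first feature
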
28 changes: 26 additions & 2 deletions tests/test_decoders.py
@@ -9,16 +9,40 @@
import terratorch # noqa: F401

from terratorch.models.decoders.aspp_head import ASPPSegmentationHead
from terratorch.models.decoders.unet_decoder import UNetDecoder
import gc


def test_aspphead():
dilations = (1, 6, 12, 18)
in_channels=6
channels=10
in_channels = 6
channels = 10
decoder = ASPPSegmentationHead(dilations=dilations, in_channels=in_channels, channels=channels, num_classes=2)

image = [torch.from_numpy(np.random.rand(2, 6, 224, 224).astype("float32"))]

assert decoder(image).shape == (2, 2, 224, 224)

gc.collect()


def test_unetdecoder():
embed_dim = [64, 128, 256, 512]
channels = [256, 128, 64, 32]
decoder = UNetDecoder(embed_dim=embed_dim, channels=channels)

image = [
torch.from_numpy(np.random.rand(2, 64, 224, 224).astype("float32")),
torch.from_numpy(np.random.rand(2, 128, 112, 112).astype("float32")),
torch.from_numpy(np.random.rand(2, 256, 56, 56).astype("float32")),
torch.from_numpy(np.random.rand(2, 512, 28, 28).astype("float32")),
]

assert decoder(image).shape == (
2,
32,
448,
448,
) # the decoder doubles the spatial size of the first feature map, since it assumes the features are already downsampled from the original image

gc.collect()
45 changes: 33 additions & 12 deletions tests/test_encoder_decoder_model_factory.py
@@ -8,7 +8,7 @@
from terratorch.models import EncoderDecoderFactory
from terratorch.models.backbones.prithvi_vit import PRETRAINED_BANDS
from terratorch.models.model import AuxiliaryHead
import gc
import gc

NUM_CHANNELS = 6
NUM_CLASSES = 2
@@ -37,6 +37,7 @@ def model_factory() -> EncoderDecoderFactory:
def model_input() -> torch.Tensor:
return torch.ones((1, NUM_CHANNELS, 224, 224))


def test_unused_args_raise_exception(model_factory: EncoderDecoderFactory):
with pytest.raises(ValueError) as excinfo:
model_factory.build_model(
@@ -46,12 +47,13 @@ def test_unused_args_raise_exception(model_factory: EncoderDecoderFactory):
backbone_bands=PRETRAINED_BANDS,
backbone_pretrained=False,
num_classes=NUM_CLASSES,
unused_argument="unused_argument"
unused_argument="unused_argument",
)
assert "unused_argument" in str(excinfo.value)

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
def test_create_classification_model(backbone, model_factory: EncoderDecoderFactory, model_input):
model = model_factory.build_model(
@@ -69,6 +71,7 @@ def test_create_classification_model(backbone, model_factory: EncoderDecoderFactory, model_input):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300"])
def test_create_classification_model_no_in_channels(backbone, model_factory: EncoderDecoderFactory, model_input):
model = model_factory.build_model(
@@ -86,9 +89,10 @@ def test_create_classification_model_no_in_channels(backbone, model_factory: EncoderDecoderFactory, model_input):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input):
model_args = {
"task": task,
@@ -100,8 +104,10 @@ def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input):

if task == "segmentation":
model_args["num_classes"] = NUM_CLASSES
if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]

model = model_factory.build_model(**model_args)
model.eval()
@@ -111,6 +117,7 @@ def test_create_pixelwise_model(backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
def test_create_model_with_smp_fpn_decoder(backbone, task, expected, model_factory: EncoderDecoderFactory, model_input):
@@ -134,6 +141,7 @@ def test_create_model_with_smp_fpn_decoder(backbone, task, expected, model_factory: EncoderDecoderFactory, model_input):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
def test_create_model_with_smp_unet_decoder(
@@ -160,6 +168,7 @@ def test_create_model_with_smp_unet_decoder(

gc.collect()


@pytest.mark.skip(reason="Failing without clear reason.")
@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@@ -186,6 +195,7 @@ def test_create_model_with_smp_deeplabv3plus_decoder(

gc.collect()


@pytest.mark.skipif(not importlib.util.find_spec("mmseg"), reason="mmsegmentation not installed")
@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@@ -199,8 +209,9 @@ def test_create_model_with_mmseg_fcn_decoder(
"decoder_channels": 128,
"backbone_bands": PRETRAINED_BANDS,
"backbone_pretrained": False,
"necks": [{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"},
"necks": [
{"name": "SelectIndices", "indices": [-1]},
{"name": "ReshapeTokensToImage"},
],
}

@@ -217,6 +228,7 @@ def test_create_model_with_mmseg_fcn_decoder(

gc.collect()


@pytest.mark.skipif(not importlib.util.find_spec("mmseg"), reason="mmsegmentation not installed")
@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@@ -250,9 +262,10 @@ def test_create_model_with_mmseg_uperhead_decoder(

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
def test_create_pixelwise_model_no_in_channels(
backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input
):
@@ -266,8 +279,10 @@ def test_create_pixelwise_model_no_in_channels(

if task == "segmentation":
model_args["num_classes"] = NUM_CLASSES
if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]

model = model_factory.build_model(**model_args)
model.eval()
@@ -277,9 +292,10 @@ def test_create_pixelwise_model_no_in_channels(

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
def test_create_pixelwise_model_with_aux_heads(
backbone, task, expected, decoder, model_factory: EncoderDecoderFactory, model_input
):
@@ -296,8 +312,10 @@ def test_create_pixelwise_model_with_aux_heads(
if task == "segmentation":
model_args["num_classes"] = NUM_CLASSES

if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]

model = model_factory.build_model(**model_args)
model.eval()
@@ -312,9 +330,10 @@ def test_create_pixelwise_model_with_aux_heads(

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100"])
@pytest.mark.parametrize("task,expected", PIXELWISE_TASK_EXPECTED_OUTPUT)
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
def test_create_pixelwise_model_with_extra_bands(
backbone, task, expected, decoder, model_factory: EncoderDecoderFactory
):
@@ -329,8 +348,10 @@ def test_create_pixelwise_model_with_extra_bands(
if task == "segmentation":
model_args["num_classes"] = NUM_CLASSES

if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]
model = model_factory.build_model(**model_args)
model.eval()
model_input = torch.ones((1, NUM_CHANNELS + 1, 224, 224))
20 changes: 14 additions & 6 deletions tests/test_prithvi_tasks.py
@@ -29,7 +29,7 @@ def model_input() -> torch.Tensor:


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
def test_create_segmentation_task(backbone, decoder, loss, model_factory: str):
model_args = {
@@ -40,8 +40,10 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: str):
"num_classes": NUM_CLASSES,
}

if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]
SemanticSegmentationTask(
model_args,
model_factory,
@@ -50,8 +52,9 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: str):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
def test_create_regression_task(backbone, decoder, loss, model_factory: str):
model_args = {
@@ -61,8 +64,10 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: str):
"backbone_pretrained": False,
}

if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]

PixelwiseRegressionTask(
model_args,
@@ -72,8 +77,9 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: str):

gc.collect()


@pytest.mark.parametrize("backbone", ["prithvi_eo_v1_100", "prithvi_eo_v2_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder", "UNetDecoder"])
@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
def test_create_classification_task(backbone, decoder, loss, model_factory: str):
model_args = {
@@ -84,8 +90,10 @@ def test_create_classification_task(backbone, decoder, loss, model_factory: str):
"num_classes": NUM_CLASSES,
}

if decoder == "UperNetDecoder" and backbone.startswith("prithvi_vit"):
if decoder in ["UperNetDecoder", "UNetDecoder"] and backbone.startswith("prithvi_eo"):
model_args["necks"] = VIT_UPERNET_NECK
if decoder == "UNetDecoder":
model_args["decoder_channels"] = [256, 128, 64, 32]

ClassificationTask(
model_args,
