Skip to content

Commit

Permalink
Merge pull request #62 from IBM/feature/upernet_scale_modules
Browse files Browse the repository at this point in the history
add scale modules to upernet for vit backbone
  • Loading branch information
CarlosGomes98 authored Jul 30, 2024
2 parents da480f0 + 57b8851 commit 3670c7b
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 229 deletions.
133 changes: 30 additions & 103 deletions terratorch/models/decoders/upernet_decoder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
# Copyright contributors to the Terratorch project

import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn

"""
Adapted from https://github.com/yassouali/pytorch-segmentation/blob/master/models/upernet.py
"""


class ConvModule(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=False) -> None: # noqa: FBT002
super().__init__()
Expand All @@ -19,103 +15,6 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=Fa
def forward(self, x):
return self.act(self.norm(self.conv(x)))


# class PSPModule(nn.Module):
# # In the original implementation they use precise RoI pooling
# # Instead of using adaptive average pooling
# def __init__(self, in_channels: int, bin_sizes: list[int] | None = None):
# super().__init__()
# if bin_sizes is None:
# bin_sizes = [1, 2, 3, 6]
# out_channels = in_channels // len(bin_sizes)
# self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
# self.bottleneck = nn.Sequential(
# nn.Conv2d(
# in_channels + (out_channels * len(bin_sizes)),
# in_channels,
# kernel_size=3,
# padding=1,
# bias=False,
# ),
# nn.BatchNorm2d(in_channels),
# nn.ReLU(inplace=True),
# nn.Dropout2d(0.1),
# )

# def _make_stages(self, in_channels, out_channels, bin_sz):
# prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
# conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
# bn = nn.BatchNorm2d(out_channels)
# relu = nn.ReLU(inplace=True)
# return nn.Sequential(prior, conv, bn, relu)

# def forward(self, features):
# h, w = features.size()[2], features.size()[3]
# pyramids = [features]
# pyramids.extend(
# [F.interpolate(stage(features), size=(h, w), mode="bilinear", align_corners=True) for stage in self.stages]
# )
# output = self.bottleneck(torch.cat(pyramids, dim=1))
# return output


# def up_and_add(x, y):
# return F.interpolate(x, size=(y.size(2), y.size(3)), mode="bilinear", align_corners=True) + y


# class FPNFuse(nn.Module):
# def __init__(self, feature_channels=None, fpn_out=256):
# super().__init__()
# if feature_channels is None:
# feature_channels = [256, 512, 1024, 2048]
# if not feature_channels[0] == fpn_out:
# msg = f"First index of feature channel ({feature_channels[0]}) did not match fpn_out ({fpn_out})"
# raise Exception(msg)
# self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]])
# self.smooth_conv = nn.ModuleList(
# [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] * (len(feature_channels) - 1)
# )
# self.conv_fusion = nn.Sequential(
# nn.Conv2d(
# len(feature_channels) * fpn_out,
# fpn_out,
# kernel_size=3,
# padding=1,
# bias=False,
# ),
# nn.BatchNorm2d(fpn_out),
# nn.ReLU(inplace=True),
# )

# def forward(self, features):
# features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1, strict=False)]
# p = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
# p = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, p, strict=False)]
# p = list(reversed(p))
# p.append(features[-1]) # P = [P1, P2, P3, P4]
# h, w = p[0].size(2), p[0].size(3)
# p[1:] = [F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=True) for feature in p[1:]]

# x = self.conv_fusion(torch.cat(p, dim=1))
# return x


# class UperNetDecoder(nn.Module):
# def __init__(self, embed_dim: list[int]) -> None:
# super().__init__()
# self.embed_dim = embed_dim
# self.output_embed_dim = embed_dim[0]
# self.PPN = PSPModule(embed_dim[-1])
# self.FPN = FPNFuse(embed_dim, fpn_out=self.output_embed_dim)

# def forward(self, x: Tensor):
# x = [f.clone() for f in x]
# x[-1] = self.PPN(x[-1])
# x = self.FPN(x)

# return x


# Adapted from MMSegmentation
class UperNetDecoder(nn.Module):
"""UperNetDecoder. Adapted from MMSegmentation."""
Expand All @@ -126,6 +25,7 @@ def __init__(
pool_scales: tuple[int] = (1, 2, 3, 6),
channels: int = 256,
align_corners: bool = True, # noqa: FBT001, FBT002
scale_modules: bool = False
):
"""Constructor
Expand All @@ -134,10 +34,29 @@ def __init__(
pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
channels (int, optional): Channels used in the decoder. Defaults to 256.
align_corners (bool, optional): Whter to align corners in rescaling. Defaults to True.
align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
Defaults to False.
"""
super().__init__()
self.embed_dim = embed_dim
self.scale_modules = scale_modules
if scale_modules:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim[0],
embed_dim[0] // 2, 2, 2),
nn.BatchNorm2d(embed_dim[0] // 2),
nn.GELU(),
nn.ConvTranspose2d(embed_dim[0] // 2,
embed_dim[0] // 4, 2, 2))
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim[1],
embed_dim[1] // 2, 2, 2))
self.fpn3 = nn.Sequential(nn.Identity())
self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
else:
self.embed_dim = embed_dim

self.output_embed_dim = channels
self.channels = channels
self.align_corners = align_corners
Expand Down Expand Up @@ -192,6 +111,14 @@ def forward(self, inputs):
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""

if self.scale_modules:
scaled_inputs = []
scaled_inputs.append(self.fpn1(inputs[0]))
scaled_inputs.append(self.fpn2(inputs[1]))
scaled_inputs.append(self.fpn3(inputs[2]))
scaled_inputs.append(self.fpn4(inputs[3]))
inputs = scaled_inputs
# build laterals
laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
laterals.append(self.psp_forward(inputs))
Expand Down
Loading

0 comments on commit 3670c7b

Please sign in to comment.