Add Vision Transformer from ScaleMAE #422

Draft · wants to merge 5 commits into base: main
23 changes: 20 additions & 3 deletions torch_em/model/unetr.py
@@ -15,7 +15,7 @@


#
# UNETR IMPLEMENTATION [Vision Transformer (ViT from MAE / ViT from SAM) + UNet Decoder from `torch_em`]
# UNETR IMPLEMENTATION [Vision Transformer (ViT from SAM / MAE / ScaleMAE) + UNet Decoder from `torch_em`]
#


@@ -57,14 +57,30 @@ def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):
# vit initialization hints from:
# - https://github.com/facebookresearch/mae/blob/main/main_finetune.py#L233-L242
encoder_state = torch.load(checkpoint)["model"]
encoder_state = OrderedDict(
{k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))}
)
# Let's remove the `head` from our current encoder (as the MAE-pretrained model does not expect it)
current_encoder_state = self.encoder.state_dict()
if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
del self.encoder.head

elif backbone == "scalemae":
# Load the encoder state directly from a checkpoint.
encoder_state = torch.load(checkpoint)["model"]
encoder_state = OrderedDict({
k: v for k, v in encoder_state.items() if (k != "mask_token" and not k.startswith("decoder"))
k: v for k, v in encoder_state.items()
if not k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))
})

# Let's remove the `head` from our current encoder (as the MAE-pretrained model does not expect it)
current_encoder_state = self.encoder.state_dict()
if ("head.weight" in current_encoder_state) and ("head.bias" in current_encoder_state):
del self.encoder.head

if "pos_embed" in current_encoder_state: # NOTE: ScaleMAE uses 'pos. embeddings' in a diff. format.
del self.encoder.pos_embed

else:
encoder_state = checkpoint
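
To make the new key filter concrete, here is a hedged sketch for inspecting which entries of a ScaleMAE checkpoint the `scalemae` branch above would discard before loading the encoder; the checkpoint path is a placeholder.

import torch

state = torch.load("scalemae_checkpoint.pth", map_location="cpu")["model"]  # placeholder path
dropped = [k for k in state if k.startswith(("mask_token", "decoder", "fcn", "fpn", "pos_embed"))]
print(f"{len(dropped)} mask-token / decoder / FPN / positional-embedding keys are ignored when loading the encoder.")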

@@ -96,6 +112,7 @@ def __init__(
if isinstance(encoder, str): # "vit_b" / "vit_l" / "vit_h"
print(f"Using {encoder} from {backbone.upper()}")
self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)

if encoder_checkpoint is not None:
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

@@ -225,7 +242,7 @@ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(1, -1, 1, 1).to(device)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(1, -1, 1, 1).to(device)
elif self.use_mae_stats:
# TODO: add mean std from mae experiments (or open up arguments for this)
# TODO: add mean std from mae / scalemae experiments (or open up arguments for this)
raise NotImplementedError
else:
pixel_mean = torch.Tensor([0.0, 0.0, 0.0]).view(1, -1, 1, 1).to(device)
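Taken together, the unetr.py changes wire the new backbone into the existing constructor. A hedged usage sketch, assuming `UNETR` is exported from `torch_em.model` and accepts the `backbone`, `encoder`, `encoder_checkpoint` and `img_size` arguments visible in the hunks above (the checkpoint path is a placeholder):

from torch_em.model import UNETR

model = UNETR(
    backbone="scalemae",
    encoder="vit_b",
    encoder_checkpoint="scalemae_vit_b.pth",  # placeholder path to pretrained ScaleMAE weights
    img_size=224,
)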
239 changes: 230 additions & 9 deletions torch_em/model/vit.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

# we catch ImportErrors here because segment_anything, micro_sam and timm should
# we catch ImportErrors here because segment_anything, micro_sam, scale_mae and timm should
# only be optional dependencies for torch_em
try:
from segment_anything.modeling import ImageEncoderViT
@@ -14,10 +14,11 @@
_sam_import_success = False

try:
from timm.models.vision_transformer import VisionTransformer
from timm.models.vision_transformer import VisionTransformer, PatchEmbed
_timm_import_success = True
except ImportError:
VisionTransformer = object
PatchEmbed = object
_timm_import_success = False


@@ -47,11 +48,7 @@ def __init__(
"and then rerun your code."
)

super().__init__(
embed_dim=embed_dim,
global_attn_indexes=global_attn_indexes,
**kwargs,
)
super().__init__(embed_dim=embed_dim, global_attn_indexes=global_attn_indexes, **kwargs)
self.chunks_for_projection = global_attn_indexes
self.in_chans = in_chans
self.embed_dim = embed_dim
@@ -159,6 +156,208 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x, list_from_encoder


#
# Utilities for ScaleMAE's ViT
#


class CustomCompose:
def __init__(self, rescale_transform, other_transforms, src_transform):
self.rescale_transform = rescale_transform
self.other_transforms = other_transforms
self.src_transform = src_transform

def __call__(self, x, valid_masks=None):
if valid_masks is not None:
nodata = (x * (1 - valid_masks.float())).max()
x_aug = self.rescale_transform(x)
parms = self.rescale_transform._params

# sanity check, comment if this is working
# valid_masks = self.rescale_transform(valid_masks.float(), params=parms)
# assert (x_aug==self.rescale_transform(x, params=parms)).all() #

if valid_masks is not None:
valid_masks = x_aug != nodata
_, c, h, w = x_aug.shape
zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy()
else:
zero_ratio = -1

if self.other_transforms:
x_aug = self.other_transforms(x_aug)
x_src = self.src_transform(x_aug)
dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0]

# dy = (parms['src'][:,2,1] - parms['src'][:,1,1])
# assert (dx == dy).all()

h, w = x_aug.shape[-2:]
# assert h == w

return x_aug, x_src, dx / h, zero_ratio, valid_masks


def get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=False, device="cpu"):
"""
grid_size: int of the grid height and width
res: array of size n, representing the resolution of a pixel (say, in meters),
return:
pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
# res = torch.FloatTensor(res).to(device)
res = res.to(device)
grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here h goes first, direction reversed for numpy
grid = torch.stack(grid, dim=0) # 2 x h x w

# grid = grid.reshape([2, 1, grid_size, grid_size])
grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w
_, n, h, w = grid.shape
pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) # (nxH*W, D/2)
pos_embed = pos_embed.reshape(n, h * w, embed_dim)
if cls_token:
pos_embed = torch.cat(
[torch.zeros([n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device), pos_embed], dim=1
)

return pos_embed


def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
# old_shape = pos
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product

emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)

emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
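
The three helpers above build the resolution-aware sin-cos positional embedding. A quick shape check following the docstring of get_2d_sincos_pos_embed_with_resolution; the per-image resolutions are made up:

import torch

res = torch.tensor([1.0, 2.5, 10.0])  # hypothetical per-image resolutions (e.g. meters per pixel)
pos_embed = get_2d_sincos_pos_embed_with_resolution(embed_dim=768, grid_size=14, res=res, cls_token=True)
print(pos_embed.shape)  # torch.Size([3, 197, 768]) -> (n, 1 + grid_size * grid_size, embed_dim)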


class PatchEmbedUnSafe(PatchEmbed):
"""Image to Patch Embedding"""

def forward(self, x):
B, C, H, W = x.shape

# NOTE: Comment code from ScaleMAE: Dropped size check in timm
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

x = self.proj(x).flatten(2).transpose(1, 2)
return x
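
PatchEmbedUnSafe drops timm's input-size check because the ScaleMAE pipeline feeds 448 x 448 crops (see transform_inputs below) into a patch embedding that was constructed with a smaller img_size. A minimal sketch of that behaviour, with made-up sizes:

import torch

pe = PatchEmbedUnSafe(img_size=224, patch_size=8, in_chans=3, embed_dim=768)
tokens = pe(torch.randn(1, 3, 448, 448))  # accepted although 448 != 224; timm's stock PatchEmbed would typically reject this
print(tokens.shape)  # torch.Size([1, 3136, 768]) -> (448 / 8) ** 2 patch tokens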


class ViT_ScaleMAE(VisionTransformer):
"""Vision Transformer dervied from the Scale Masked Auto Encoder codebase (TODO: paper and github link).
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=12, **kwargs):
super().__init__(img_size=img_size, embed_dim=embed_dim, depth=depth, **kwargs)  # forward `depth` so that vit_l / vit_h build all of their blocks
self.img_size = img_size
self.in_chans = in_chans
self.depth = depth

self.patch_embed = PatchEmbedUnSafe(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)

def transform_inputs(self, x):
import kornia.augmentation as K
from kornia.constants import Resample

# "base_resoulution" needs to be adjusted manually when using the model on a different zoom factor dataset.
base_resolution = 2.5

self._transforms = CustomCompose(
rescale_transform=K.RandomResizedCrop(
(448, 448),
ratio=(1.0, 1.0),
scale=(1.0, 1.0),
resample=Resample.BICUBIC.name,
),
other_transforms=None,
src_transform=K.Resize((224, 224)),
)
x, _, ratios, _, _ = self._transforms(x)
input_res = ratios * base_resolution
return x, input_res

def convert_to_expected_dim(self, x):
inputs_ = x[:, 1:, :]  # remove the class token
# reshape the outputs to the expected spatial format (N x H*W x C -> N x C x H x W)
rdim = inputs_.shape[1]
dshape = int(rdim ** 0.5) # finding square root of the outputs for obtaining the patch shape
inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape))
inputs_ = inputs_.permute(0, 3, 1, 2)
return inputs_

def forward_features(self, x):
x, input_res = self.transform_inputs(x)

B, _, h, w = x.shape
x = self.patch_embed(x)

num_patches = int((h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]))
pos_embed = get_2d_sincos_pos_embed_with_resolution(
x.shape[-1],
int(num_patches ** 0.5),
input_res,
cls_token=True,
device=x.device,
)

cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)

# choose the transformer blocks whose outputs are projected and passed to the upsampling (decoder) blocks
_chunks = int(self.depth / 4)
chunks_for_projection = [_chunks - 1, 2*_chunks - 1, 3*_chunks - 1, 4*_chunks - 1]

list_from_encoder = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in chunks_for_projection:
list_from_encoder.append(self.convert_to_expected_dim(x))

x = self.convert_to_expected_dim(x)

return x, list_from_encoder

def forward(self, x):
x, list_from_encoder = self.forward_features(x)
return x, list_from_encoder
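
A hedged end-to-end sketch of the encoder above, using the ViT-B hyperparameters from the `scalemae` branch of get_vision_transformer below and assuming kornia is installed (transform_inputs imports it); the spatial size of the outputs follows from the internal 448 x 448 rescale:

from functools import partial

import torch
import torch.nn as nn

vit = ViT_ScaleMAE(
    img_size=224, patch_size=8, embed_dim=768, depth=12, num_heads=12,
    mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
x = torch.randn(1, 3, 224, 224)
features, intermediates = vit(x)
print(features.shape)      # final NCHW feature map, e.g. torch.Size([1, 768, 56, 56])
print(len(intermediates))  # 4 intermediate feature maps for the UNETR decoder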


def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> nn.Module:
"""Get vision transformer encoder.

@@ -196,7 +395,7 @@ def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> nn.Module:
window_size=14, out_chans=256
)
else:
raise ValueError(f"{model} is not supported by SAM. Currently vit_b, vit_l, vit_h are supported.")
raise ValueError(f"'{model}' is not supported by SAM. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

elif backbone == "mae":
if model == "vit_b":
@@ -215,9 +414,31 @@ def get_vision_transformer(backbone: str, model: str, img_size: int = 1024) -> nn.Module:
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
else:
raise ValueError(f"{model} is not supported by MAE. Currently vit_b, vit_l, vit_h are supported.")
raise ValueError(f"'{model}' is not supported by MAE. Currently, 'vit_b', 'vit_l', 'vit_h' are supported.")

elif backbone == "scalemae":
if model == "vit_b":
encoder = ViT_ScaleMAE(
img_size=img_size, patch_size=8, embed_dim=768, depth=12, num_heads=12,
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
elif model == "vit_l":
encoder = ViT_ScaleMAE(
img_size=img_size, patch_size=8, embed_dim=1024, depth=24, num_heads=16,
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
elif model == "vit_h":
encoder = ViT_ScaleMAE(
img_size=img_size, patch_size=8, embed_dim=1280, depth=32, num_heads=16,
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
else:
raise ValueError(
f"'{model}' is not supported by ScaleMAE. Currently, 'vit_b', 'vit_l' and 'vit_h' are supported."
)

else:
raise ValueError("The 'UNETR' supported backbones are `sam`, `mae` or 'scalemae. Please choose one of them.")
raise ValueError("The UNETR supported backbones are `sam` or `mae`. Please choose either of the two.")

return encoder
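
Finally, a hedged sketch of the extended factory; each branch requires its optional dependency (segment_anything for `sam`, timm for `mae` and `scalemae`, plus kornia at forward time for `scalemae`), and no pretrained weights are loaded here:

sam_encoder = get_vision_transformer(backbone="sam", model="vit_b", img_size=1024)
mae_encoder = get_vision_transformer(backbone="mae", model="vit_b", img_size=224)
scalemae_encoder = get_vision_transformer(backbone="scalemae", model="vit_b", img_size=224)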