Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of a Masked Autoencoder for representation learning #8152

Merged
merged 9 commits into from
Nov 27, 2024
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ Nets
.. autoclass:: ViTAutoEnc
:members:

`MaskedAutoEncoderViT`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MaskedAutoEncoderViT
:members:

`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .masked_autoencoder_vit import MaskedAutoEncoderViT
from .mednext import (
MedNeXt,
MedNext,
Expand Down
211 changes: 211 additions & 0 deletions monai/networks/nets/masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import trunc_normal_
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

# Decoder positional-embedding modes; `decoder_pos_embed_type` is checked against this set
# with `look_up_option` in `MaskedAutoEncoderViT.__init__`.
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

# Names exported by this module.
__all__ = ["MaskedAutoEncoderViT"]


class MaskedAutoEncoderViT(nn.Module):
"""
Masked Autoencoder (ViT), based on: "Kaiming et al.,
Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
the masked patches, resulting in improved training speed.
"""

def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int = 768,
mlp_dim: int = 512,
num_layers: int = 12,
num_heads: int = 12,
masking_ratio: float = 0.75,
decoder_hidden_size: int = 384,
decoder_mlp_dim: int = 512,
decoder_num_layers: int = 4,
decoder_num_heads: int = 12,
proj_type: str = "conv",
pos_embed_type: str = "sincos",
decoder_pos_embed_type: str = "sincos",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
in_channels: dimension of input channels or the number of channels for input.
img_size: dimension of input image.
patch_size: dimension of patch size
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 512.
num_layers: number of transformer blocks. Defaults to 12.
num_heads: number of attention heads. Defaults to 12.
masking_ratio: ratio of patches to be masked. Defaults to 0.75.
decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
decoder_num_heads: number of attention heads for decoder. Defaults to 12.
proj_type: position embedding layer type. Defaults to "conv".
pos_embed_type: position embedding layer type. Defaults to "sincos".
decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
dropout_rate: fraction of the input units to drop. Defaults to 0.0.
spatial_dims: number of spatial dimensions. Defaults to 3.
qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
Examples::
# for single channel input with image size of (96,96,96), and sin-cos positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
pos_embed_type='sincos')
# for 3-channel with image size of (128,128,128) and a learnable positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
# for 3-channel with image size of (224,224) and a masking ratio of 0.25
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
spatial_dims=2)
"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")

if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")

if decoder_hidden_size % decoder_num_heads != 0:
raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")

self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
self.img_size = ensure_tuple_rep(img_size, spatial_dims)
self.spatial_dims = spatial_dims
for m, p in zip(self.img_size, self.patch_size):
if m % p != 0:
raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")

self.decoder_hidden_size = decoder_hidden_size

if masking_ratio <= 0 or masking_ratio >= 1:
raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")

self.masking_ratio = masking_ratio
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

self.patch_embedding = PatchEmbeddingBlock(
in_channels=in_channels,
img_size=img_size,
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
proj_type=proj_type,
pos_embed_type=pos_embed_type,
dropout_rate=dropout_rate,
spatial_dims=self.spatial_dims,
)
blocks = [
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(num_layers)
]
self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))

# decoder
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)

self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))

decoder_blocks = [
TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(decoder_num_layers)
]
self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)

self._init_weights()

def _init_weights(self):
"""
similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and
classification tokens
"""
if self.decoder_pos_embed_type == "none":
pass
elif self.decoder_pos_embed_type == "learnable":
trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
elif self.decoder_pos_embed_type == "sincos":
grid_size = []
for in_size, pa_size in zip(self.img_size, self.patch_size):
grid_size.append(in_size // pa_size)

self.decoder_pos_embedding = build_sincos_position_embedding(
grid_size, self.decoder_hidden_size, self.spatial_dims
)

else:
raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")

# initialize patch_embedding like nn.Linear (instead of nn.Conv2d)
trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)

def _masking(self, x, masking_ratio: float | None = None):
batch_size, num_tokens, _ = x.shape
percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
selected_indices = torch.multinomial(
torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
)
x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens
mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0

return x_masked, selected_indices, mask

def forward(self, x, masking_ratio: float | None = None):
x = self.patch_embedding(x)
x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)

cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

x = self.blocks(x)

# decoder
x = self.decoder_embed(x)

x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :] # no cls token
x_ = x_ + self.decoder_pos_embedding
x = torch.cat([x[:, :1, :], x_], dim=1)
x = self.decoder_blocks(x)
x = self.decoder_pred(x)

x = x[:, 1:, :]
return x, mask
160 changes: 160 additions & 0 deletions tests/test_masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
from tests.utils import skip_if_quick

# Shape test cases: [constructor kwargs, input shape, expected output shape].
# Only the image size, positional-embedding type, projection type and spatial dimensionality vary;
# every other hyper-parameter is fixed, so it is a plain constant rather than a one-element loop.
TEST_CASE_MaskedAutoEncoderViT = []
in_channels = 4
patch_size = 16
for img_size in [96, 128]:
    for pos_embed_type in ["sincos", "learnable"]:
        for proj_type in ["conv", "perceptron"]:
            for nd in (2, 3):
                test_case = [
                    {
                        "in_channels": in_channels,
                        "img_size": (img_size,) * nd,
                        "patch_size": (patch_size,) * nd,
                        "hidden_size": 768,
                        "mlp_dim": 3072,
                        "num_layers": 4,
                        "decoder_hidden_size": 384,
                        "decoder_mlp_dim": 512,
                        "decoder_num_layers": 4,
                        "decoder_num_heads": 16,
                        "pos_embed_type": pos_embed_type,
                        "masking_ratio": 0.5,
                        # the decoder uses the same positional-embedding type as the encoder
                        "decoder_pos_embed_type": pos_embed_type,
                        "num_heads": 12,
                        "proj_type": proj_type,
                        "dropout_rate": 0.6,
                        # passed explicitly for both 2D and 3D (3 is the constructor default)
                        "spatial_dims": nd,
                    },
                    # input: batch of 2 images
                    (2, in_channels, *([img_size] * nd)),
                    # output: one flattened patch (all channels) per patch position
                    (2, (img_size // patch_size) ** nd, in_channels * (patch_size**nd)),
                ]
                TEST_CASE_MaskedAutoEncoderViT.append(test_case)

# Constructor kwargs that `test_ill_arg` expects `MaskedAutoEncoderViT` to reject with a ValueError.
TEST_CASE_ill_args = [
    # dropout_rate outside [0, 1]
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
    # unknown encoder pos_embed_type; forwarded to PatchEmbeddingBlock, which presumably rejects it
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}],
    # unknown decoder_pos_embed_type (not one of "none", "learnable", "sincos")
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}],
    # image smaller than a patch, so img_size % patch_size != 0
    [{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}],
    # default hidden_size (768) is not divisible by num_heads=14
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}],
    # img_size 97 is not divisible by patch_size 16
    [{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}],
    # masking_ratio above the open interval (0, 1)
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}],
    # masking_ratio below the open interval (0, 1)
    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}],
]


@skip_if_quick
class TestMaskedAutoencoderViT(unittest.TestCase):
    """Shape, argument-validation, attention-access and masking-ratio tests for MaskedAutoEncoderViT."""

    @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)
    def test_shape(self, input_param, input_shape, expected_shape):
        """The reconstruction has one flattened patch per patch position."""
        net = MaskedAutoEncoderViT(**input_param)
        with eval_mode(net):
            result, _ = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)

    def test_frozen_pos_embedding(self):
        """With the default ("sincos") decoder positional embedding, the embedding is not trainable."""
        net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))

        self.assertFalse(net.decoder_pos_embedding.requires_grad)

    @parameterized.expand(TEST_CASE_ill_args)
    def test_ill_arg(self, input_param):
        """Invalid constructor arguments raise ValueError."""
        with self.assertRaises(ValueError):
            MaskedAutoEncoderViT(**input_param)

    def test_access_attn_matrix(self):
        """The attention matrix is only stored when `save_attn=True`."""
        # input format
        in_channels = 1
        img_size = (96, 96, 96)
        patch_size = (16, 16, 16)
        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

        # no data in the matrix
        no_matrix_access_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
        no_matrix_access_blk(torch.randn(in_shape))
        self.assertIsInstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
        # number of elements is zero
        self.assertEqual(no_matrix_access_blk.blocks[0].attn.att_mat.nelement(), 0)

        # be able to access the attention matrix
        matrix_access_blk = MaskedAutoEncoderViT(
            in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True
        )
        matrix_access_blk(torch.randn(in_shape))

        # 12 heads (default); 6**3 = 216 patches, 25% kept at the default masking ratio -> 54, plus cls token
        self.assertEqual(matrix_access_blk.blocks[0].attn.att_mat.shape, (in_shape[0], 12, 55, 55))

    def test_masking_ratio(self):
        """The encoder sees `int(n_patches * (1 - masking_ratio))` tokens plus the cls token."""
        # input format
        in_channels = 1
        img_size = (96, 96, 96)
        patch_size = (16, 16, 16)
        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])

        for masking_ratio in (0.25, 0.33):
            with self.subTest(masking_ratio=masking_ratio):
                masking_ratio_blk = MaskedAutoEncoderViT(
                    in_channels=in_channels,
                    img_size=img_size,
                    patch_size=patch_size,
                    masking_ratio=masking_ratio,
                    save_attn=True,
                )
                masking_ratio_blk(torch.randn(in_shape))
                desired_num_tokens = int(n_patches * (1 - masking_ratio))
                # the attention matrix also covers the cls token, hence the "- 1"
                self.assertEqual(masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1, desired_num_tokens)


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading