diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 64a3a4c9d1..e2e509a99b 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -630,6 +630,11 @@ Nets
 .. autoclass:: ViTAutoEnc
     :members:
 
+`MaskedAutoEncoderViT`
+~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: MaskedAutoEncoderViT
+    :members:
+
 `FullyConnectedNet`
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: FullyConnectedNet
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index b876e6a3fc..c1917e5293 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -53,6 +53,7 @@
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .masked_autoencoder_vit import MaskedAutoEncoderViT
 from .mednext import (
     MedNeXt,
     MedNext,
diff --git a/monai/networks/nets/masked_autoencoder_vit.py b/monai/networks/nets/masked_autoencoder_vit.py
new file mode 100644
index 0000000000..e76f097346
--- /dev/null
+++ b/monai/networks/nets/masked_autoencoder_vit.py
@@ -0,0 +1,211 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
+from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.networks.layers import trunc_normal_
+from monai.utils import ensure_tuple_rep
+from monai.utils.module import look_up_option
+
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+
+__all__ = ["MaskedAutoEncoderViT"]
+
+
+class MaskedAutoEncoderViT(nn.Module):
+    """
+    Masked Autoencoder (ViT), based on: "He et al.,
+    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
+    Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
+    the masked patches, resulting in improved training speed.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        img_size: Sequence[int] | int,
+        patch_size: Sequence[int] | int,
+        hidden_size: int = 768,
+        mlp_dim: int = 512,
+        num_layers: int = 12,
+        num_heads: int = 12,
+        masking_ratio: float = 0.75,
+        decoder_hidden_size: int = 384,
+        decoder_mlp_dim: int = 512,
+        decoder_num_layers: int = 4,
+        decoder_num_heads: int = 12,
+        proj_type: str = "conv",
+        pos_embed_type: str = "sincos",
+        decoder_pos_embed_type: str = "sincos",
+        dropout_rate: float = 0.0,
+        spatial_dims: int = 3,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+    ) -> None:
+        """
+        Args:
+            in_channels: number of input channels.
+            img_size: dimension of input image.
+            patch_size: dimension of patch size.
+            hidden_size: dimension of hidden layer. Defaults to 768.
+            mlp_dim: dimension of feedforward layer. Defaults to 512.
+            num_layers: number of transformer blocks. Defaults to 12.
+            num_heads: number of attention heads. Defaults to 12.
+            masking_ratio: ratio of patches to be masked. Defaults to 0.75.
+            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
+            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
+            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
+            decoder_num_heads: number of attention heads for decoder. Defaults to 12.
+            proj_type: patch embedding layer type. Defaults to "conv".
+            pos_embed_type: position embedding layer type. Defaults to "sincos".
+            decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
+            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
+            spatial_dims: number of spatial dimensions. Defaults to 3.
+            qkv_bias: apply bias to the qkv linear layer in the self attention block. Defaults to False.
+            save_attn: to make the attention in the self attention block accessible. Defaults to False.
+
+        Examples::
+
+            # for single channel input with image size of (96,96,96) and sin-cos positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
+                    pos_embed_type='sincos')
+
+            # for 3-channel input with image size of (128,128,128) and a learnable positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
+
+            # for 3-channel input with image size of (224,224) and a masking ratio of 0.25
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
+                    spatial_dims=2)
+        """
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden_size should be divisible by num_heads.")
+
+        if decoder_hidden_size % decoder_num_heads != 0:
+            raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")
+
+        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
+        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
+        self.spatial_dims = spatial_dims
+        for m, p in zip(self.img_size, self.patch_size):
+            if m % p != 0:
+                raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")
+
+        self.decoder_hidden_size = decoder_hidden_size
+
+        if masking_ratio <= 0 or masking_ratio >= 1:
+            raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")
+
+        self.masking_ratio = masking_ratio
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+
+        self.patch_embedding = PatchEmbeddingBlock(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=patch_size,
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            proj_type=proj_type,
+            pos_embed_type=pos_embed_type,
+            dropout_rate=dropout_rate,
+            spatial_dims=self.spatial_dims,
+        )
+        blocks = [
+            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(num_layers)
+        ]
+        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))
+
+        # decoder
+        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)
+
+        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
+
+        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
+        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))
+
+        decoder_blocks = [
+            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(decoder_num_layers)
+        ]
+        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
+        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        """
+        Initialize the decoder positional encoding and the mask and classification tokens,
+        similar to monai/networks/blocks/patchembedding.py.
+        """
+        if self.decoder_pos_embed_type == "none":
+            pass
+        elif self.decoder_pos_embed_type == "learnable":
+            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        elif self.decoder_pos_embed_type == "sincos":
+            grid_size = []
+            for in_size, pa_size in zip(self.img_size, self.patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.decoder_pos_embedding = build_sincos_position_embedding(
+                grid_size, self.decoder_hidden_size, self.spatial_dims
+            )
+
+        else:
+            raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")
+
+        # initialize the mask and classification tokens
+        trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)
+
+    def _masking(self, x, masking_ratio: float | None = None):
+        batch_size, num_tokens, _ = x.shape
+        percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
+        selected_indices = torch.multinomial(
+            torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
+        )
+        x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens
+        mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
+        mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0
+
+        return x_masked, selected_indices, mask
+
+    def forward(self, x, masking_ratio: float | None = None):
+        x = self.patch_embedding(x)
+        x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)
+
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        x = self.blocks(x)
+
+        # decoder
+        x = self.decoder_embed(x)
+
+        x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
+        x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token
+        x_ = x_ + self.decoder_pos_embedding
+        x = torch.cat([x[:, :1, :], x_], dim=1)
+        x = self.decoder_blocks(x)
+        x = self.decoder_pred(x)
+
+        x = x[:, 1:, :]
+        return x, mask
diff --git a/tests/test_masked_autoencoder_vit.py b/tests/test_masked_autoencoder_vit.py
new file mode 100644
index 0000000000..f8f6977cc2
--- /dev/null
+++ b/tests/test_masked_autoencoder_vit.py
@@ -0,0 +1,160 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
+from tests.utils import skip_if_quick
+
+TEST_CASE_MaskedAutoEncoderViT = []
+for masking_ratio in [0.5]:
+    for dropout_rate in [0.6]:
+        for in_channels in [4]:
+            for hidden_size in [768]:
+                for img_size in [96, 128]:
+                    for patch_size in [16]:
+                        for num_heads in [12]:
+                            for mlp_dim in [3072]:
+                                for num_layers in [4]:
+                                    for decoder_hidden_size in [384]:
+                                        for decoder_mlp_dim in [512]:
+                                            for decoder_num_layers in [4]:
+                                                for decoder_num_heads in [16]:
+                                                    for pos_embed_type in ["sincos", "learnable"]:
+                                                        for proj_type in ["conv", "perceptron"]:
+                                                            for nd in (2, 3):
+                                                                test_case = [
+                                                                    {
+                                                                        "in_channels": in_channels,
+                                                                        "img_size": (img_size,) * nd,
+                                                                        "patch_size": (patch_size,) * nd,
+                                                                        "hidden_size": hidden_size,
+                                                                        "mlp_dim": mlp_dim,
+                                                                        "num_layers": num_layers,
+                                                                        "decoder_hidden_size": decoder_hidden_size,
+                                                                        "decoder_mlp_dim": decoder_mlp_dim,
+                                                                        "decoder_num_layers": decoder_num_layers,
+                                                                        "decoder_num_heads": decoder_num_heads,
+                                                                        "pos_embed_type": pos_embed_type,
+                                                                        "masking_ratio": masking_ratio,
+                                                                        "decoder_pos_embed_type": pos_embed_type,
+                                                                        "num_heads": num_heads,
+                                                                        "proj_type": proj_type,
+                                                                        "dropout_rate": dropout_rate,
+                                                                    },
+                                                                    (2, in_channels, *([img_size] * nd)),
+                                                                    (
+                                                                        2,
+                                                                        (img_size // patch_size) ** nd,
+                                                                        in_channels * (patch_size**nd),
+                                                                    ),
+                                                                ]
+                                                                if nd == 2:
+                                                                    test_case[0]["spatial_dims"] = 2  # type: ignore
+                                                                TEST_CASE_MaskedAutoEncoderViT.append(test_case)
+
+TEST_CASE_ill_args = [
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}],
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}],
+    [{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}],
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}],
+    [{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}],
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}],
+    [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}],
+]
+
+
+@skip_if_quick
+class TestMaskedAutoencoderViT(unittest.TestCase):
+
+    @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)
+    def test_shape(self, input_param, input_shape, expected_shape):
+        net = MaskedAutoEncoderViT(**input_param)
+        with eval_mode(net):
+            result, _ = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
+    def test_frozen_pos_embedding(self):
+        net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
+
+        self.assertEqual(net.decoder_pos_embedding.requires_grad, False)
+
+    @parameterized.expand(TEST_CASE_ill_args)
+    def test_ill_arg(self, input_param):
+        with self.assertRaises(ValueError):
+            MaskedAutoEncoderViT(**input_param)
+
+    def test_access_attn_matrix(self):
+        # input format
+        in_channels = 1
+        img_size = (96, 96, 96)
+        patch_size = (16, 16, 16)
+        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])
+
+        # no data in the matrix
+        no_matrix_access_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
+        no_matrix_access_blk(torch.randn(in_shape))
+        assert isinstance(no_matrix_access_blk.blocks[0].attn.att_mat, torch.Tensor)
+        # number of elements is zero
+        assert no_matrix_access_blk.blocks[0].attn.att_mat.nelement() == 0
+
+        # be able to access the attention matrix
+        matrix_access_blk = MaskedAutoEncoderViT(
+            in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True
+        )
+        matrix_access_blk(torch.randn(in_shape))
+
+        assert matrix_access_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55)
+
+    def test_masking_ratio(self):
+        # input format
+        in_channels = 1
+        img_size = (96, 96, 96)
+        patch_size = (16, 16, 16)
+        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])
+
+        # masking ratio 0.25
+        masking_ratio_blk = MaskedAutoEncoderViT(
+            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True
+        )
+        masking_ratio_blk(torch.randn(in_shape))
+        desired_num_tokens = int(
+            (img_size[0] // patch_size[0])
+            * (img_size[1] // patch_size[1])
+            * (img_size[2] // patch_size[2])
+            * (1 - 0.25)
+        )
+        assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens
+
+        # masking ratio 0.33
+        masking_ratio_blk = MaskedAutoEncoderViT(
+            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True
+        )
+        masking_ratio_blk(torch.randn(in_shape))
+        desired_num_tokens = int(
+            (img_size[0] // patch_size[0])
+            * (img_size[1] // patch_size[1])
+            * (img_size[2] // patch_size[2])
+            * (1 - 0.33)
+        )
+
+        assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens
+
+
+if __name__ == "__main__":
+    unittest.main()
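
The forward pass returns the reconstructed patches together with the binary patch mask, and leaves the reconstruction loss to the caller. Below is a minimal training-side sketch, assuming the standard masked-autoencoder objective of mean-squared error over the masked patches only; the `patchify` helper is a hypothetical illustration of one way to build a target with the same (batch, n_patches, patch_dim) layout as the prediction and is not part of this change set.

import torch

from monai.networks.nets import MaskedAutoEncoderViT


def patchify(x: torch.Tensor, patch_size: tuple[int, int, int]) -> torch.Tensor:
    # hypothetical helper: split a (B, C, D, H, W) volume into flattened non-overlapping patches
    b, c = x.shape[:2]
    pd, ph, pw = patch_size
    d, h, w = (s // p for s, p in zip(x.shape[2:], patch_size))
    x = x.reshape(b, c, d, pd, h, ph, w, pw)
    x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)  # (B, d, h, w, pd, ph, pw, C)
    return x.reshape(b, d * h * w, pd * ph * pw * c)


net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
img = torch.randn(2, 1, 96, 96, 96)

# pred: (2, 216, 4096) reconstructed patches, mask: (2, 216) with 1 marking masked patches
pred, mask = net(img)

target = patchify(img, (16, 16, 16))
per_patch_mse = ((pred - target) ** 2).mean(dim=-1)  # (2, 216)
loss = (per_patch_mse * mask).sum() / mask.sum()  # average the error over masked patches only
loss.backward()

The masking ratio chosen at construction time can also be overridden for a single call, e.g. net(img, masking_ratio=0.5), since forward accepts an optional masking_ratio argument.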