Skip to content

Commit

Permalink
2407 Add support to set number classes for TorchVision models (Projec…
Browse files Browse the repository at this point in the history
…t-MONAI#2408)

* [DLMED] add TorchVisionClassificationModel

Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>
  • Loading branch information
Nic-Ma and monai-bot authored Jun 24, 2021
1 parent 0662ed2 commit 79b83d9
Show file tree
Hide file tree
Showing 11 changed files with 434 additions and 51 deletions.
10 changes: 10 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,16 @@ Nets
.. autoclass:: Critic
:members:

`NetAdapter`
~~~~~~~~~~~~
.. autoclass:: NetAdapter
:members:

`TorchVisionFCModel`
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TorchVisionFCModel
:members:

`TorchVisionFullyConvModel`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TorchVisionFullyConvModel
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
separable_filtering,
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
26 changes: 24 additions & 2 deletions monai/networks/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from typing import Optional, Tuple, Union

from monai.networks.layers.factories import Act, Dropout, Norm, split_args
from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
from monai.utils import has_option

__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer"]
__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]


def get_norm_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1, channels: Optional[int] = 1):
Expand Down Expand Up @@ -92,3 +92,25 @@ def get_dropout_layer(name: Union[Tuple, str, float, int], dropout_dim: Optional
drop_name, drop_args = split_args(name)
drop_type = Dropout[drop_name, dropout_dim]
return drop_type(**drop_args)


def get_pool_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1):
"""
Create a pooling layer instance.
For example, to create adaptiveavg layer:
.. code-block:: python
from monai.networks.layers import get_pool_layer
pool_layer = get_pool_layer(("adaptiveavg", {"output_size": (1, 1, 1)}), spatial_dims=3)
Args:
name: a pooling type string or a tuple of type string and parameters.
spatial_dims: number of spatial dimensions of the input.
"""
pool_name, pool_args = split_args(name)
pool_type = Pool[pool_name, spatial_dims]
return pool_type(**pool_args)
3 changes: 2 additions & 1 deletion monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .netadapter import NetAdapter
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
from .resnet import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
Expand Down Expand Up @@ -71,7 +72,7 @@
seresnext50,
seresnext101,
)
from .torchvision_fc import TorchVisionFullyConvModel
from .torchvision_fc import TorchVisionFCModel, TorchVisionFullyConvModel
from .unet import UNet, Unet, unet
from .varautoencoder import VarAutoEncoder
from .vnet import VNet
102 changes: 102 additions & 0 deletions monai/networks/nets/netadapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch

from monai.networks.layers import Conv, get_pool_layer


class NetAdapter(torch.nn.Module):
"""
Wrapper to replace the last layer of model by convolutional layer or FC layer.
This module expects the output of `model layers[0: -2]` is a feature map with shape [B, C, spatial dims],
then replace the model's last two layers with an optional `pooling` and a `conv` or `linear` layer.
Args:
model: a PyTorch model, support both 2D and 3D models. typically, it can be a pretrained model in Torchvision,
like: ``resnet18``, ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
more details: https://pytorch.org/vision/stable/models.html.
n_classes: number of classes for the last classification layer. Default to 1.
dim: number of spatial dimensions, default to 2.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
use_conv: whether use convolutional layer to replace the last layer, default to False.
pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
default to `("avg", {"kernel_size": 7, "stride": 1})`.
bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
default to True.
"""

def __init__(
self,
model: torch.nn.Module,
n_classes: int = 1,
dim: int = 2,
in_channels: Optional[int] = None,
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
):
super().__init__()
layers = list(model.children())
orig_fc = layers[-1]
in_channels_: int

if in_channels is None:
if not hasattr(orig_fc, "in_features"):
raise ValueError("please specify the input channels of last layer with arg `in_channels`.")
in_channels_ = orig_fc.in_features # type: ignore
else:
in_channels_ = in_channels

if pool is None:
self.pool = None
# remove the last layer
self.features = torch.nn.Sequential(*layers[:-1])
else:
self.pool = get_pool_layer(name=pool, spatial_dims=dim)
# remove the last 2 layers
self.features = torch.nn.Sequential(*layers[:-2])

self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d]
if use_conv:
# add 1x1 conv (it behaves like a FC layer)
self.fc = Conv[Conv.CONV, dim](
in_channels=in_channels_,
out_channels=n_classes,
kernel_size=1,
bias=bias,
)
else:
# remove the last Linear layer (fully connected)
self.features = torch.nn.Sequential(*layers[:-1])
# replace the out_features of FC layer
self.fc = torch.nn.Linear(
in_features=in_channels_,
out_features=n_classes,
bias=bias,
)
self.use_conv = use_conv

def forward(self, x):
x = self.features(x)
if self.pool is not None:
x = self.pool(x)

if not self.use_conv:
x = torch.flatten(x, 1)

x = self.fc(x)

return x
94 changes: 60 additions & 34 deletions monai/networks/nets/torchvision_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,66 +9,92 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import torch

from monai.utils import optional_import
from monai.networks.nets import NetAdapter
from monai.utils import deprecated, optional_import

models, _ = optional_import("torchvision.models")


__all__ = ["TorchVisionFullyConvModel"]
__all__ = ["TorchVisionFCModel", "TorchVisionFullyConvModel"]


class TorchVisionFullyConvModel(torch.nn.Module):
class TorchVisionFCModel(NetAdapter):
"""
Customize TorchVision models to replace fully connected layer by convolutional layer.
Customize the fully connected layer of TorchVision model or replace it by convolutional layer.
Args:
model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
model_name: name of any torchvision model with fully connected layer at the end.
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
model details: https://pytorch.org/vision/stable/models.html.
n_classes: number of classes for the last classification layer. Default to 1.
pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
dim: number of spatial dimensions, default to 2.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
use_conv: whether use convolutional layer to replace the last layer, default to False.
pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
default to `("avg", {"kernel_size": 7, "stride": 1})`.
bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
default to True.
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

def __init__(
self,
model_name: str = "resnet18",
n_classes: int = 1,
pool_size: Union[int, Tuple[int, int]] = (7, 7),
pool_stride: Union[int, Tuple[int, int]] = 1,
dim: int = 2,
in_channels: Optional[int] = None,
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
pretrained: bool = False,
):
super().__init__()
model = getattr(models, model_name)(pretrained=pretrained)
layers = list(model.children())

# check if the model is compatible
if not str(layers[-1]).startswith("Linear"):
# check if the model is compatible, should have a FC layer at the end
if not str(list(model.children())[-1]).startswith("Linear"):
raise ValueError(f"Model ['{model_name}'] does not have a Linear layer at the end.")
if not str(layers[-2]).startswith("AdaptiveAvgPool2d"):
raise ValueError(f"Model ['{model_name}'] does not have a AdaptiveAvgPool2d layer next to the end.")

# remove the last Linear layer (fully connected) and the adaptive avg pooling
self.features = torch.nn.Sequential(*layers[:-2])
super().__init__(
model=model,
n_classes=n_classes,
dim=dim,
in_channels=in_channels,
use_conv=use_conv,
pool=pool,
bias=bias,
)

# add 7x7 avg pooling (in place of adaptive avg pooling)
self.pool = torch.nn.AvgPool2d(kernel_size=pool_size, stride=pool_stride)

# add 1x1 conv (it behaves like a FC layer)
self.fc = torch.nn.Conv2d(model.fc.in_features, n_classes, kernel_size=(1, 1))

def forward(self, x):
x = self.features(x)

# apply 2D avg pooling
x = self.pool(x)
@deprecated(since="0.6.0", version_val="0.7.0", msg_suffix="please consider to use `TorchVisionFCModel` instead.")
class TorchVisionFullyConvModel(TorchVisionFCModel):
"""
Customize TorchVision models to replace fully connected layer by convolutional layer.
# apply last 1x1 conv layer that act like a linear layer
x = self.fc(x)
Args:
model_name: name of any torchvision with adaptive avg pooling and fully connected layer at the end.
``resnet18`` (default), ``resnet34m``, ``resnet50``, ``resnet101``, ``resnet152``,
``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
n_classes: number of classes for the last classification layer. Default to 1.
pool_size: the kernel size for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to (7, 7).
pool_stride: the stride for `AvgPool2d` to replace `AdaptiveAvgPool2d`. Default to 1.
pretrained: whether to use the imagenet pretrained weights. Default to False.
"""

return x
def __init__(
self,
model_name: str = "resnet18",
n_classes: int = 1,
pool_size: Union[int, Tuple[int, int]] = (7, 7),
pool_stride: Union[int, Tuple[int, int]] = 1,
pretrained: bool = False,
):
super().__init__(
model_name=model_name,
n_classes=n_classes,
use_conv=True,
pool=("avg", {"kernel_size": pool_size, "stride": pool_stride}),
pretrained=pretrained,
)
65 changes: 65 additions & 0 deletions tests/test_net_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import NetAdapter, resnet18

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_0 = [
{"n_classes": 1, "use_conv": True, "dim": 2},
(2, 3, 224, 224),
(2, 1, 8, 1),
]

TEST_CASE_1 = [
{"n_classes": 1, "use_conv": True, "dim": 3, "pool": None},
(2, 3, 32, 32, 32),
(2, 1, 1, 1, 1),
]

TEST_CASE_2 = [
{"n_classes": 5, "use_conv": True, "dim": 3, "pool": None},
(2, 3, 32, 32, 32),
(2, 5, 1, 1, 1),
]

TEST_CASE_3 = [
{"n_classes": 5, "use_conv": True, "pool": ("avg", {"kernel_size": 4, "stride": 1}), "dim": 3},
(2, 3, 128, 128, 128),
(2, 5, 5, 1, 1),
]

TEST_CASE_4 = [
{"n_classes": 5, "use_conv": False, "pool": ("adaptiveavg", {"output_size": (1, 1, 1)}), "dim": 3},
(2, 3, 32, 32, 32),
(2, 5),
]


class TestNetAdapter(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, input_shape, expected_shape):
model = resnet18(spatial_dims=input_param["dim"])
input_param["model"] = model
net = NetAdapter(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_normalize_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_default(self):
normalized = normalizer(self.imt.copy())
self.assertTrue(normalized.dtype == np.float32)
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
np.testing.assert_allclose(normalized, expected, rtol=1e-5)
np.testing.assert_allclose(normalized, expected, rtol=1e-3)

@parameterized.expand(TEST_CASES)
def test_nonzero(self, input_param, input_data, expected_data):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_normalize_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_image_normalize_intensityd(self):
normalizer = NormalizeIntensityd(keys=[key])
normalized = normalizer({key: self.imt})
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
np.testing.assert_allclose(normalized[key], expected, rtol=1e-5)
np.testing.assert_allclose(normalized[key], expected, rtol=1e-3)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_nonzero(self, input_param, input_data, expected_data):
Expand Down
Loading

0 comments on commit 79b83d9

Please sign in to comment.