Refactoring ConvModule by removing norm_cfg (#3816)
* Update `build_norm_layer` to use `norm_callable`

* WIP

* Replace `norm_cfg` with `norm_callable`

* Update `activation_callable` docstring

* Update `CHANGELOG`

* Enable using pre-assigned nn.Module

* Update to use pre-assigned norm layer in `ConvModule`

* Fix

* Enable `partial(build_norm_layer, ...)`

* Fix unit test

* Fix typo

* Fix unit test

* Fix unit test

* Restore `build_activation_layer` and update `ConvModule` to use a pre-assigned module

* Update to use `build_activation_layer`

* Fix

* Enable getting an nn.Module

* Remove `callable` in arg name

* Fix unit test

* Fix rtdetr18

* Fix

* Apply pre-commit fixes

* Fix
sungchul2 authored Aug 19, 2024
1 parent 43f1fc9 commit f50e821
Showing 66 changed files with 1,446 additions and 1,138 deletions.
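For orientation before the per-file diffs: the refactor replaces the dict-based `norm_cfg`/`act_cfg` configs with callables and pre-built modules. Below is a minimal sketch of the new calling convention, inferred from the diffs that follow; the channel and kernel values are illustrative, not taken from the repository.

```python
from torch import nn

from otx.algo.modules.conv_module import Conv3dModule
from otx.algo.modules.norm import build_norm_layer

# Old style, removed by this PR:
#   Conv3dModule(3, 24, 3, norm_cfg={"type": "BN3d"}, activation_callable=nn.ReLU)

# New style: `normalization` and `activation` receive already-built nn.Modules.
conv = Conv3dModule(
    3,   # in_channels (illustrative)
    24,  # out_channels (illustrative)
    3,   # kernel_size (illustrative)
    bias=False,
    normalization=build_norm_layer(nn.BatchNorm3d, num_features=24),
    activation=nn.ReLU(inplace=True),
)
```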
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -41,8 +41,8 @@ All notable changes to this project will be documented in this file.
   (<https://github.com/openvinotoolkit/training_extensions/pull/3759>)
 - Enable to use polygon and bitmap mask as prompt inputs for zero-shot learning
   (<https://github.com/openvinotoolkit/training_extensions/pull/3769>)
-- Refactoring `ConvModule` by removing `conv_cfg` and `act_cfg`
-  (<https://github.com/openvinotoolkit/training_extensions/pull/3783>, <https://github.com/openvinotoolkit/training_extensions/pull/3809>)
+- Refactoring `ConvModule` by removing `conv_cfg`, `norm_cfg`, and `act_cfg`
+  (<https://github.com/openvinotoolkit/training_extensions/pull/3783>, <https://github.com/openvinotoolkit/training_extensions/pull/3816>, <https://github.com/openvinotoolkit/training_extensions/pull/3809>)
 
 ### Bug fixes
 
89 changes: 45 additions & 44 deletions src/otx/algo/action_classification/backbones/x3d.py
@@ -13,8 +13,9 @@
 from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from otx.algo.modules.activation import Swish
+from otx.algo.modules.activation import Swish, build_activation_layer
 from otx.algo.modules.conv_module import Conv3dModule
+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.utils.mmengine_utils import load_checkpoint
 from otx.algo.utils.weight_init import constant_init, kaiming_init
 
@@ -72,10 +73,10 @@ class BlockX3D(nn.Module):
             unit. If set as None, it means not using SE unit. Default: None.
         use_swish (bool): Whether to use swish as the activation function
             before and after the 3x3x3 conv. Default: True.
-        norm_cfg (dict): Config for norm layers. required keys are ``type``,
-            Default: ``dict(type='BN3d')``.
-        activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-            Defaults to `nn.ReLU`.
+        normalization (Callable[..., nn.Module] | None): Normalization layer module.
+            Defaults to None.
+        activation (Callable[..., nn.Module] | None): Activation layer module.
+            Defaults to ``nn.ReLU``.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed. Default: False.
     """
@@ -89,8 +90,8 @@ def __init__(
         downsample: nn.Module | None = None,
         se_ratio: float | None = None,
         use_swish: bool = True,
-        norm_cfg: dict | None = None,
-        activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+        normalization: Callable[..., nn.Module] | None = None,
+        activation: Callable[..., nn.Module] | None = nn.ReLU,
         with_cp: bool = False,
     ):
         super().__init__()
@@ -102,8 +103,8 @@ def __init__(
         self.downsample = downsample
         self.se_ratio = se_ratio
         self.use_swish = use_swish
-        self.norm_cfg = norm_cfg
-        self.activation_callable = activation_callable
+        self.normalization = normalization
+        self.activation = activation
         self.with_cp = with_cp
 
         self.conv1 = Conv3dModule(
@@ -113,8 +114,8 @@ def __init__(
             stride=1,
             padding=0,
             bias=False,
-            norm_cfg=self.norm_cfg,
-            activation_callable=self.activation_callable,
+            normalization=build_norm_layer(normalization, num_features=planes),
+            activation=build_activation_layer(activation),
         )
         # Here we use the channel-wise conv
         self.conv2 = Conv3dModule(
@@ -125,8 +126,8 @@ def __init__(
             padding=1,
             groups=planes,
             bias=False,
-            norm_cfg=self.norm_cfg,
-            activation_callable=None,
+            normalization=build_norm_layer(normalization, num_features=planes),
+            activation=None,
         )
 
         self.swish = Swish()
@@ -138,14 +139,14 @@ def __init__(
             stride=1,
             padding=0,
             bias=False,
-            norm_cfg=self.norm_cfg,
-            activation_callable=None,
+            normalization=build_norm_layer(normalization, num_features=outplanes),
+            activation=None,
         )
 
         if self.se_ratio is not None:
             self.se_module = SEModule(planes, self.se_ratio)
 
-        self.relu = self.activation_callable() if self.activation_callable else nn.ReLU(inplace=True)
+        self.relu = self.activation() if self.activation else nn.ReLU(inplace=True)
 
     def forward(self, x: Tensor) -> Tensor:
         """Defines the computation performed at every call."""
@@ -195,11 +196,10 @@ class X3DBackbone(nn.Module):
             unit. If set as None, it means not using SE unit. Default: 1 / 16.
         use_swish (bool): Whether to use swish as the activation function
             before and after the 3x3x3 conv. Default: True.
-        norm_cfg (dict): Config for norm layers. required keys are ``type`` and
-            ``requires_grad``.
-            Default: ``dict(type='BN3d', requires_grad=True)``.
-        activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-            Defaults to `nn.ReLU`.
+        normalization (Callable[..., nn.Module] | None): Normalization layer module.
+            Defaults to None.
+        activation (Callable[..., nn.Module] | None): Activation layer module.
+            Defaults to ``nn.ReLU``.
         norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
             running stats (mean and var). Default: False.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
@@ -223,8 +223,8 @@ def __init__(
         se_style: str = "half",
         se_ratio: float = 1 / 16,
         use_swish: bool = True,
-        norm_cfg: dict | None = None,
-        activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+        normalization: Callable[..., nn.Module] | None = None,
+        activation: Callable[..., nn.Module] | None = nn.ReLU,
         norm_eval: bool = False,
         with_cp: bool = False,
         zero_init_residual: bool = True,
@@ -266,8 +266,8 @@ def __init__(
             raise ValueError(msg)
         self.use_swish = use_swish
 
-        self.norm_cfg = norm_cfg
-        self.activation_callable = activation_callable
+        self.normalization = normalization
+        self.activation = activation
         self.norm_eval = norm_eval
         self.with_cp = with_cp
         self.zero_init_residual = zero_init_residual
@@ -293,8 +293,8 @@ def __init__(
                 se_style=self.se_style,
                 se_ratio=self.se_ratio,
                 use_swish=self.use_swish,
-                norm_cfg=self.norm_cfg,
-                activation_callable=self.activation_callable,
+                normalization=self.normalization,
+                activation=self.activation,
                 with_cp=with_cp,
                 **kwargs,
             )
@@ -311,8 +311,8 @@ def __init__(
             stride=1,
             padding=0,
             bias=False,
-            norm_cfg=self.norm_cfg,
-            activation_callable=self.activation_callable,
+            normalization=build_norm_layer(self.normalization, num_features=int(self.feat_dim * self.gamma_b)),
+            activation=build_activation_layer(self.activation),
         )
         self.feat_dim = int(self.feat_dim * self.gamma_b)
 
@@ -349,8 +349,8 @@ def make_res_layer(
         se_style: str = "half",
         se_ratio: float | None = None,
         use_swish: bool = True,
-        norm_cfg: dict | None = None,
-        activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
+        normalization: Callable[..., nn.Module] | None = None,
+        activation: Callable[..., nn.Module] | None = nn.ReLU,
         with_cp: bool = False,
         **kwargs,
    ) -> nn.Module:
@@ -375,9 +375,10 @@
                 Default: None.
             use_swish (bool): Whether to use swish as the activation function
                 before and after the 3x3x3 conv. Default: True.
-            norm_cfg (dict | None): Config for norm layers. Default: None.
-            activation_callable (Callable[..., nn.Module] | None): Activation layer module.
-                Defaults to `nn.ReLU`.
+            normalization (Callable[..., nn.Module] | None): Normalization layer module.
+                Defaults to None.
+            activation (Callable[..., nn.Module] | None): Activation layer module.
+                Defaults to ``nn.ReLU``.
             with_cp (bool | None): Use checkpoint or not. Using checkpoint
                 will save some memory while slowing down the training speed.
                 Default: False.
@@ -394,8 +395,8 @@
             stride=(1, spatial_stride, spatial_stride),
             padding=0,
             bias=False,
-            norm_cfg=norm_cfg,
-            activation_callable=None,
+            normalization=build_norm_layer(normalization, num_features=inplanes),
+            activation=None,
         )
 
         use_se = [False] * blocks
@@ -416,8 +417,8 @@
                 downsample=downsample,
                 se_ratio=se_ratio if use_se[0] else None,
                 use_swish=use_swish,
-                norm_cfg=norm_cfg,
-                activation_callable=activation_callable,
+                normalization=normalization,
+                activation=activation,
                 with_cp=with_cp,
                 **kwargs,
             ),
@@ -432,8 +433,8 @@
                     spatial_stride=1,
                     se_ratio=se_ratio if use_se[i] else None,
                     use_swish=use_swish,
-                    norm_cfg=norm_cfg,
-                    activation_callable=activation_callable,
+                    normalization=normalization,
+                    activation=activation,
                     with_cp=with_cp,
                     **kwargs,
                 ),
@@ -450,8 +451,8 @@ def _make_stem_layer(self) -> None:
             stride=(1, 2, 2),
             padding=(0, 1, 1),
             bias=False,
-            norm_cfg=None,
-            activation_callable=None,
+            normalization=None,
+            activation=None,
         )
         self.conv1_t = Conv3dModule(
             self.base_channels,
@@ -461,8 +462,8 @@
             padding=(2, 0, 0),
             groups=self.base_channels,
             bias=False,
-            norm_cfg=self.norm_cfg,
-            activation_callable=self.activation_callable,
+            normalization=build_norm_layer(self.normalization, num_features=self.base_channels),
+            activation=build_activation_layer(self.activation),
         )
 
     def _freeze_stages(self) -> None:
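The repeated pattern `build_norm_layer(normalization, num_features=...)` above, together with the commit bullets "Enable using pre-assigned nn.Module" and "Enable `partial(build_norm_layer, ...)`", pins down the contract of the new helper: it must accept None, an already-built module, a norm class such as `nn.BatchNorm3d`, or a config-level `partial` of itself. The following is a hypothetical sketch of that dispatch, assumed for illustration only; it is not the actual body of `otx.algo.modules.norm.build_norm_layer`.

```python
from functools import partial
from typing import Callable

from torch import nn


def build_norm_layer(
    normalization: Callable[..., nn.Module] | nn.Module | None,
    num_features: int | None = None,
    requires_grad: bool = True,
    **kwargs,
) -> nn.Module | None:
    """Hypothetical dispatch inferred from the call sites in this diff."""
    if normalization is None:
        return None
    if isinstance(normalization, nn.Module):
        # Pre-assigned module: pass through unchanged.
        return normalization
    if isinstance(normalization, partial) and normalization.func is build_norm_layer:
        # Model configs pass `partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)`;
        # forward the runtime channel count into the partial.
        return normalization(num_features=num_features)
    layer = normalization(num_features, **kwargs)  # e.g. nn.BatchNorm3d(24)
    for param in layer.parameters():
        param.requires_grad = requires_grad
    return layer
```

Under this reading, `None` disables normalization (as in the stem's first conv), and a config-level `partial` simply defers `num_features` until each layer's channel count is known.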
5 changes: 3 additions & 2 deletions src/otx/algo/action_classification/x3d.py
@@ -13,6 +13,7 @@
 from otx.algo.action_classification.backbones.x3d import X3DBackbone
 from otx.algo.action_classification.heads.x3d_head import X3DHead
 from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer
+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.utils.mmengine_utils import load_checkpoint
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
 from otx.core.metrics.accuracy import MultiClassClsMetricCallable
@@ -67,8 +68,8 @@ def _build_model(self, num_classes: int) -> nn.Module:
                 gamma_b=2.25,
                 gamma_d=2.2,
                 gamma_w=1,
-                norm_cfg={"type": "BN3d", "requires_grad": True},
-                activation_callable=partial(nn.ReLU, inplace=True),
+                normalization=partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True),
+                activation=partial(nn.ReLU, inplace=True),
             ),
             cls_head=X3DHead(
                 num_classes=num_classes,
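At the model level, the norm layer is now configured by partially applying `build_norm_layer`, deferring `num_features` until each `Conv3dModule` knows its channel count. Assuming the sketch above, the flow would be:

```python
from functools import partial

from torch import nn

from otx.algo.modules.norm import build_norm_layer

# Stored on the backbone as `self.normalization`:
normalization = partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)

# Later, each block fills in its own channel count (24 is illustrative):
bn = build_norm_layer(normalization, num_features=24)  # behaves like nn.BatchNorm3d(24)
```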