Commit

add back is_norm
eugene123tw committed May 8, 2024
1 parent b42e1d8 commit ef0be7b
Showing 3 changed files with 20 additions and 2 deletions.
@@ -18,6 +18,7 @@
     unmap,
 )
 from otx.algo.modules.conv_module import ConvModule
+from otx.algo.modules.norm import is_norm
 from otx.algo.modules.scale import Scale
 from otx.algo.utils.mmengine_utils import InstanceData
 from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init
@@ -108,7 +109,8 @@ def init_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 normal_init(m, mean=0, std=0.01)
-            constant_init(m, 1)
+            if is_norm(m):
+                constant_init(m, 1)
         bias_cls = bias_init_with_prob(0.01)
         normal_init(self.rtm_cls, std=0.01, bias=bias_cls)
         normal_init(self.rtm_reg, std=0.01)
@@ -27,6 +27,7 @@
 from otx.algo.instance_segmentation.mmdet.structures.bbox.transforms import get_box_wh, scale_boxes
 from otx.algo.modules.base_module import BaseModule
 from otx.algo.modules.conv_module import ConvModule
+from otx.algo.modules.norm import is_norm
 from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init

 from .rtmdet_head import RTMDetHead
@@ -912,7 +913,8 @@ def init_weights(self) -> None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 normal_init(m, mean=0, std=0.01)
-            constant_init(m, 1)
+            if is_norm(m):
+                constant_init(m, 1)
         bias_cls = bias_init_with_prob(0.01)
         for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
             normal_init(rtm_cls, std=0.01, bias=bias_cls)
14 changes: 14 additions & 0 deletions src/otx/algo/modules/norm.py
@@ -132,3 +132,17 @@ def build_norm_layer(cfg: dict, num_features: int, postfix: int | str = "") -> t
         param.requires_grad = requires_grad

     return name, layer
+
+
+def is_norm(layer: nn.Module) -> bool:
+    """Check if a layer is a normalization layer.
+
+    Args:
+        layer (nn.Module): The layer to be checked.
+
+    Returns:
+        bool: Whether the layer is a norm layer.
+    """
+    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+    return isinstance(layer, all_norm_bases)
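For context, a minimal standalone sketch of how the restored helper gates the init_weights loops above, assuming the otx package from this repository is importable; the toy Sequential module and the final print line are illustrative assumptions, while the imports and init calls mirror the diff:

import torch.nn as nn

from otx.algo.modules.norm import is_norm
from otx.algo.utils.weight_init import constant_init, normal_init

# Toy stand-in for a detection head: one conv layer followed by a norm layer.
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

for m in toy.modules():
    if isinstance(m, nn.Conv2d):
        normal_init(m, mean=0, std=0.01)  # conv weights drawn with mean 0, std 0.01
    if is_norm(m):
        constant_init(m, 1)  # only the BatchNorm2d layer is touched here

print([is_norm(m) for m in toy])  # [False, True, False]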
