diff --git a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_head.py b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_head.py index 6ed41265b7a..b133cd17a97 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_head.py @@ -18,6 +18,7 @@ unmap, ) from otx.algo.modules.conv_module import ConvModule +from otx.algo.modules.norm import is_norm from otx.algo.modules.scale import Scale from otx.algo.utils.mmengine_utils import InstanceData from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init @@ -108,7 +109,8 @@ def init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): normal_init(m, mean=0, std=0.01) - constant_init(m, 1) + if is_norm(m): + constant_init(m, 1) bias_cls = bias_init_with_prob(0.01) normal_init(self.rtm_cls, std=0.01, bias=bias_cls) normal_init(self.rtm_reg, std=0.01) diff --git a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_ins_head.py index 7eec5dbeba0..743f5bbb279 100644 --- a/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_ins_head.py +++ b/src/otx/algo/instance_segmentation/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -27,6 +27,7 @@ from otx.algo.instance_segmentation.mmdet.structures.bbox.transforms import get_box_wh, scale_boxes from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import ConvModule +from otx.algo.modules.norm import is_norm from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init from .rtmdet_head import RTMDetHead @@ -912,7 +913,8 @@ def init_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): normal_init(m, mean=0, std=0.01) - constant_init(m, 1) + if is_norm(m): + constant_init(m, 1) bias_cls = bias_init_with_prob(0.01) for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg): normal_init(rtm_cls, std=0.01, bias=bias_cls) diff --git a/src/otx/algo/modules/norm.py b/src/otx/algo/modules/norm.py index 587e6840ca8..88764f43585 100644 --- a/src/otx/algo/modules/norm.py +++ b/src/otx/algo/modules/norm.py @@ -132,3 +132,17 @@ def build_norm_layer(cfg: dict, num_features: int, postfix: int | str = "") -> t param.requires_grad = requires_grad return name, layer + + +def is_norm(layer: nn.Module) -> bool: + """Check if a layer is a normalization layer. + + Args: + layer (nn.Module): The layer to be checked. + exclude (type | tuple[type]): Types to be excluded. + + Returns: + bool: Whether the layer is a norm layer. + """ + all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) + return isinstance(layer, all_norm_bases)