From 40a505347d5658e732c19c579117fd9197560e24 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Tue, 4 Apr 2023 18:11:36 +0800 Subject: [PATCH 1/5] Move predict to base angle dense head --- mmrotate/models/dense_heads/__init__.py | 3 +- .../dense_heads/base_angle_dense_head.py | 297 ++++++++++++++++++ .../models/dense_heads/rotated_fcos_head.py | 260 +-------------- .../models/dense_heads/rotated_rtmdet_head.py | 260 +-------------- 4 files changed, 329 insertions(+), 491 deletions(-) create mode 100644 mmrotate/models/dense_heads/base_angle_dense_head.py diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py index 0c4b6ef36..8351d79c1 100644 --- a/mmrotate/models/dense_heads/__init__.py +++ b/mmrotate/models/dense_heads/__init__.py @@ -12,11 +12,12 @@ from .rotated_rtmdet_head import RotatedRTMDetHead, RotatedRTMDetSepBNHead from .s2a_head import S2AHead, S2ARefineHead from .sam_reppoints_head import SAMRepPointsHead +from .base_angle_dense_head import BaseAngleDenseHead __all__ = [ 'RotatedRetinaHead', 'OrientedRPNHead', 'RotatedRepPointsHead', 'SAMRepPointsHead', 'AngleBranchRetinaHead', 'RotatedATSSHead', 'RotatedFCOSHead', 'OrientedRepPointsHead', 'R3Head', 'R3RefineHead', 'S2AHead', 'S2ARefineHead', 'CFAHead', 'H2RBoxHead', 'RotatedRTMDetHead', - 'RotatedRTMDetSepBNHead' + 'RotatedRTMDetSepBNHead', 'BaseAngleDenseHead' ] diff --git a/mmrotate/models/dense_heads/base_angle_dense_head.py b/mmrotate/models/dense_heads/base_angle_dense_head.py new file mode 100644 index 000000000..f537a52ab --- /dev/null +++ b/mmrotate/models/dense_heads/base_angle_dense_head.py @@ -0,0 +1,297 @@ +import copy +from typing import List, Optional + +import torch +from mmdet.models.dense_heads.base_dense_head import BaseDenseHead +from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl +from mmdet.structures.bbox import cat_boxes +from mmdet.utils import InstanceList +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptInstanceList, reduce_mean) + +from mmrotate.structures import RotatedBoxes +from mmrotate.registry import MODELS, TASK_UTILS + + +@MODELS.register_module() +class BaseAngleDenseHead(BaseDenseHead): + """ + Base class for dense heads with angle. + Commonly, BaseAngleDenseHead will be used with other head. + + Args: + angle_version (str, optional): The version of angle. Defaults to + 'le90'. + use_hbbox_loss (bool, optional): Whether to use the loss of + horizontal bboxes. Defaults to False. + angle_coder (dict, optional): Config dict for angle coder. + Defaults to dict(type='PseudoAngleCoder'). + loss_angle (dict, optional): Config dict for angle loss. + Defaults to None. + """ + + def __init__(self, + angle_version: str = 'le90', + use_hbbox_loss: bool = False, + angle_coder: ConfigType = dict(type='PseudoAngleCoder'), + loss_angle: OptConfigType = None, + *args, **kwargs): + self.angle_version = angle_version + self.use_hbbox_loss = use_hbbox_loss + self.angle_coder = TASK_UTILS.build(angle_coder) + if loss_angle is not None: + self.loss_angle = MODELS.build(loss_angle) + else: + self.loss_angle = None + if self.use_hbbox_loss: + assert self.loss_angle is not None + # Commonly, BaseAngleDenseHead will be used with other head. + # So we call super here to init the other head. + # For example, RotatedFCOSHead will used with FCOSHead, + # so super here will call the init function of FCOSHead. + super().__init__(*args, **kwargs) + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + angle_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + angle_preds (list[Tensor]): Box angle for each scale level + with shape (N, num_points * encode_size, H, W) + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 5), + the last dimension 5 arrange as (x, y, w, h, t). + """ + assert len(cls_scores) == len(bbox_preds) + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. + with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device) + + result_list = [] + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl( + cls_scores, img_id, detach=True) + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + angle_pred_list = select_single_mlvl( + angle_preds, img_id, detach=True) + if with_score_factors: + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + else: + score_factor_list = [None for _ in range(num_levels)] + + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + angle_pred_list=angle_pred_list, + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + angle_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + angle_pred_list (list[Tensor]): Box angle for a single scale + level with shape (N, num_points * encode_size, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 5), + the last dimension 5 arrange as (x, y, w, h, t). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_decoded_angles = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, ( + cls_score, bbox_pred, angle_pred, score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, angle_pred_list, + score_factor_list, mlvl_priors)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + # dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + angle_pred = angle_pred.permute(1, 2, 0).reshape( + -1, self.angle_coder.encode_size) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = cls_score.softmax(-1)[:, :-1] + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + score_thr = cfg.get('score_thr', 0) + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict( + bbox_pred=bbox_pred, angle_pred=angle_pred, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + angle_pred = filtered_results['angle_pred'] + priors = filtered_results['priors'] + + decoded_angle = self.angle_coder.decode(angle_pred, keepdim=True) + + if with_score_factors: + score_factor = score_factor[keep_idxs] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_decoded_angles.append(decoded_angle) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + if with_score_factors: + mlvl_score_factors.append(score_factor) + + bbox_pred = torch.cat(mlvl_bbox_preds) + decoded_angle = torch.cat(mlvl_decoded_angles) + priors = cat_boxes(mlvl_valid_priors) + + decode_with_angle = cfg.get('decode_with_angle', True) + if decode_with_angle: + bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + else: + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + bboxes = torch.cat([bboxes[..., :4], decoded_angle], dim=-1) + + results = InstanceData() + results.bboxes = RotatedBoxes(bboxes) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + if with_score_factors: + results.score_factors = torch.cat(mlvl_score_factors) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) diff --git a/mmrotate/models/dense_heads/rotated_fcos_head.py b/mmrotate/models/dense_heads/rotated_fcos_head.py index b8fce7cb5..e0f5ba602 100644 --- a/mmrotate/models/dense_heads/rotated_fcos_head.py +++ b/mmrotate/models/dense_heads/rotated_fcos_head.py @@ -17,12 +17,13 @@ from mmrotate.registry import MODELS, TASK_UTILS from mmrotate.structures import RotatedBoxes +from mmrotate.models.dense_heads.base_angle_dense_head import BaseAngleDenseHead INF = 1e8 -@MODELS.register_module() -class RotatedFCOSHead(FCOSHead): +# @MODELS.register_module() +class RotatedFCOSHead(BaseAngleDenseHead, FCOSHead): """Anchor-free head used in `FCOS `_. Compared with FCOS head, Rotated FCOS head add a angle branch to @@ -78,25 +79,24 @@ def __init__(self, loss_weight=1.0), loss_angle: OptConfigType = None, **kwargs): - self.angle_version = angle_version - self.use_hbbox_loss = use_hbbox_loss self.is_scale_angle = scale_angle - self.angle_coder = TASK_UTILS.build(angle_coder) + # There are two super classes, so all the arguments should be + # passed to the super class. super().__init__( + # Arguments for FCOSHead num_classes=num_classes, in_channels=in_channels, bbox_coder=bbox_coder, loss_cls=loss_cls, loss_bbox=loss_bbox, loss_centerness=loss_centerness, + # Arguments for BaseAngleDenseHead + angle_version=angle_version, + use_hbbox_loss=use_hbbox_loss, + angle_coder=angle_coder, + loss_angle=loss_angle, **kwargs) - if loss_angle is not None: - self.loss_angle = MODELS.build(loss_angle) - else: - self.loss_angle = None - if self.use_hbbox_loss: - assert self.loss_angle is not None - self.h_bbox_coder = TASK_UTILS.build(h_bbox_coder) + self.h_bbox_coder = TASK_UTILS.build(h_bbox_coder) def _init_layers(self): """Initialize layers of the head.""" @@ -437,235 +437,9 @@ def _get_targets_single( return labels, bbox_targets, angle_targets - def predict_by_feat(self, - cls_scores: List[Tensor], - bbox_preds: List[Tensor], - angle_preds: List[Tensor], - score_factors: Optional[List[Tensor]] = None, - batch_img_metas: Optional[List[dict]] = None, - cfg: Optional[ConfigDict] = None, - rescale: bool = False, - with_nms: bool = True) -> InstanceList: - """Transform a batch of output features extracted from the head into - bbox results. - Note: When score_factors is not None, the cls_scores are - usually multiplied by it then obtain the real score used in NMS, - such as CenterNess in FCOS, IoU branch in ATSS. - Args: - cls_scores (list[Tensor]): Classification scores for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * 4, H, W). - angle_preds (list[Tensor]): Box angle for each scale level - with shape (N, num_points * encode_size, H, W) - score_factors (list[Tensor], optional): Score factor for - all scale level, each is a 4D-tensor, has shape - (batch_size, num_priors * 1, H, W). Defaults to None. - batch_img_metas (list[dict], Optional): Batch image meta info. - Defaults to None. - cfg (ConfigDict, optional): Test / postprocessing - configuration, if None, test_cfg would be used. - Defaults to None. - rescale (bool): If True, return boxes in original image space. - Defaults to False. - with_nms (bool): If True, do nms before return boxes. - Defaults to True. - Returns: - list[:obj:`InstanceData`]: Object detection results of each image - after the post process. Each item usually contains following keys. - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 5), - the last dimension 5 arrange as (x, y, w, h, t). - """ - assert len(cls_scores) == len(bbox_preds) - - if score_factors is None: - # e.g. Retina, FreeAnchor, Foveabox, etc. - with_score_factors = False - else: - # e.g. FCOS, PAA, ATSS, AutoAssign, etc. - with_score_factors = True - assert len(cls_scores) == len(score_factors) - - num_levels = len(cls_scores) - featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_priors = self.prior_generator.grid_priors( - featmap_sizes, - dtype=cls_scores[0].dtype, - device=cls_scores[0].device) - - result_list = [] - - for img_id in range(len(batch_img_metas)): - img_meta = batch_img_metas[img_id] - cls_score_list = select_single_mlvl( - cls_scores, img_id, detach=True) - bbox_pred_list = select_single_mlvl( - bbox_preds, img_id, detach=True) - angle_pred_list = select_single_mlvl( - angle_preds, img_id, detach=True) - if with_score_factors: - score_factor_list = select_single_mlvl( - score_factors, img_id, detach=True) - else: - score_factor_list = [None for _ in range(num_levels)] - - results = self._predict_by_feat_single( - cls_score_list=cls_score_list, - bbox_pred_list=bbox_pred_list, - angle_pred_list=angle_pred_list, - score_factor_list=score_factor_list, - mlvl_priors=mlvl_priors, - img_meta=img_meta, - cfg=cfg, - rescale=rescale, - with_nms=with_nms) - result_list.append(results) - return result_list - - def _predict_by_feat_single(self, - cls_score_list: List[Tensor], - bbox_pred_list: List[Tensor], - angle_pred_list: List[Tensor], - score_factor_list: List[Tensor], - mlvl_priors: List[Tensor], - img_meta: dict, - cfg: ConfigDict, - rescale: bool = False, - with_nms: bool = True) -> InstanceData: - """Transform a single image's features extracted from the head into - bbox results. - Args: - cls_score_list (list[Tensor]): Box scores from all scale - levels of a single image, each item has shape - (num_priors * num_classes, H, W). - bbox_pred_list (list[Tensor]): Box energies / deltas from - all scale levels of a single image, each item has shape - (num_priors * 4, H, W). - angle_pred_list (list[Tensor]): Box angle for a single scale - level with shape (N, num_points * encode_size, H, W). - score_factor_list (list[Tensor]): Score factor from all scale - levels of a single image, each item has shape - (num_priors * 1, H, W). - mlvl_priors (list[Tensor]): Each element in the list is - the priors of a single level in feature pyramid. In all - anchor-based methods, it has shape (num_priors, 4). In - all anchor-free methods, it has shape (num_priors, 2) - when `with_stride=True`, otherwise it still has shape - (num_priors, 4). - img_meta (dict): Image meta info. - cfg (mmengine.Config): Test / postprocessing configuration, - if None, test_cfg would be used. - rescale (bool): If True, return boxes in original image space. - Defaults to False. - with_nms (bool): If True, do nms before return boxes. - Defaults to True. - Returns: - :obj:`InstanceData`: Detection results of each image - after the post process. - Each item usually contains following keys. - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 5), - the last dimension 5 arrange as (x, y, w, h, t). - """ - if score_factor_list[0] is None: - # e.g. Retina, FreeAnchor, etc. - with_score_factors = False - else: - # e.g. FCOS, PAA, ATSS, etc. - with_score_factors = True - - cfg = self.test_cfg if cfg is None else cfg - cfg = copy.deepcopy(cfg) - img_shape = img_meta['img_shape'] - nms_pre = cfg.get('nms_pre', -1) - - mlvl_bbox_preds = [] - mlvl_valid_priors = [] - mlvl_scores = [] - mlvl_labels = [] - if with_score_factors: - mlvl_score_factors = [] - else: - mlvl_score_factors = None - for level_idx, ( - cls_score, bbox_pred, angle_pred, score_factor, priors) in \ - enumerate(zip(cls_score_list, bbox_pred_list, angle_pred_list, - score_factor_list, mlvl_priors)): - - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - - # dim = self.bbox_coder.encode_size - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - angle_pred = angle_pred.permute(1, 2, 0).reshape( - -1, self.angle_coder.encode_size) - if with_score_factors: - score_factor = score_factor.permute(1, 2, - 0).reshape(-1).sigmoid() - cls_score = cls_score.permute(1, 2, - 0).reshape(-1, self.cls_out_channels) - if self.use_sigmoid_cls: - scores = cls_score.sigmoid() - else: - # remind that we set FG labels to [0, num_class-1] - # since mmdet v2.0 - # BG cat_id: num_class - scores = cls_score.softmax(-1)[:, :-1] - - # After https://github.com/open-mmlab/mmdetection/pull/6268/, - # this operation keeps fewer bboxes under the same `nms_pre`. - # There is no difference in performance for most models. If you - # find a slight drop in performance, you can set a larger - # `nms_pre` than before. - score_thr = cfg.get('score_thr', 0) - - results = filter_scores_and_topk( - scores, score_thr, nms_pre, - dict( - bbox_pred=bbox_pred, angle_pred=angle_pred, priors=priors)) - scores, labels, keep_idxs, filtered_results = results - - bbox_pred = filtered_results['bbox_pred'] - angle_pred = filtered_results['angle_pred'] - priors = filtered_results['priors'] - - decoded_angle = self.angle_coder.decode(angle_pred, keepdim=True) - bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1) - - if with_score_factors: - score_factor = score_factor[keep_idxs] - - mlvl_bbox_preds.append(bbox_pred) - mlvl_valid_priors.append(priors) - mlvl_scores.append(scores) - mlvl_labels.append(labels) - - if with_score_factors: - mlvl_score_factors.append(score_factor) - - bbox_pred = torch.cat(mlvl_bbox_preds) - priors = cat_boxes(mlvl_valid_priors) - bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) - - results = InstanceData() - results.bboxes = RotatedBoxes(bboxes) - results.scores = torch.cat(mlvl_scores) - results.labels = torch.cat(mlvl_labels) - if with_score_factors: - results.score_factors = torch.cat(mlvl_score_factors) - - return self._bbox_post_process( - results=results, - cfg=cfg, - rescale=rescale, - with_nms=with_nms, - img_meta=img_meta) +if __name__=='__main__': + from mmrotate.utils import register_all_modules + register_all_modules() + fcos_head = RotatedFCOSHead(num_classes=2, in_channels=256, feat_channels=256, strides=[8, 16, 32, 64, 128], angle_version='le90') + print(RotatedFCOSHead.mro()) \ No newline at end of file diff --git a/mmrotate/models/dense_heads/rotated_rtmdet_head.py b/mmrotate/models/dense_heads/rotated_rtmdet_head.py index 6ee13b941..038fc6058 100644 --- a/mmrotate/models/dense_heads/rotated_rtmdet_head.py +++ b/mmrotate/models/dense_heads/rotated_rtmdet_head.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy from typing import List, Optional, Tuple import torch @@ -7,23 +6,23 @@ from mmdet.models import inverse_sigmoid from mmdet.models.dense_heads import RTMDetHead from mmdet.models.task_modules import anchor_inside_flags -from mmdet.models.utils import (filter_scores_and_topk, multi_apply, - select_single_mlvl, sigmoid_geometric_mean, +from mmdet.models.utils import (multi_apply, + sigmoid_geometric_mean, unmap) -from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, cat_boxes, distance2bbox +from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, distance2bbox from mmdet.utils import (ConfigType, InstanceList, OptConfigType, OptInstanceList, reduce_mean) -from mmengine import ConfigDict from mmengine.model import bias_init_with_prob, constant_init, normal_init from mmengine.structures import InstanceData from torch import Tensor, nn -from mmrotate.registry import MODELS, TASK_UTILS -from mmrotate.structures import RotatedBoxes, distance2obb +from mmrotate.registry import MODELS +from mmrotate.structures import distance2obb +from mmrotate.models.dense_heads.base_angle_dense_head import BaseAngleDenseHead @MODELS.register_module() -class RotatedRTMDetHead(RTMDetHead): +class RotatedRTMDetHead(BaseAngleDenseHead, RTMDetHead): """Detection Head of Rotated RTMDet. Args: @@ -48,11 +47,9 @@ def __init__(self, angle_coder: ConfigType = dict(type='PseudoAngleCoder'), loss_angle: OptConfigType = None, **kwargs) -> None: - self.angle_version = angle_version - self.use_hbbox_loss = use_hbbox_loss self.is_scale_angle = scale_angle - self.angle_coder = TASK_UTILS.build(angle_coder) super().__init__( + # Arguments of RTMDetHead num_classes, in_channels, # useless, but error @@ -60,11 +57,12 @@ def __init__(self, type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + # Arguments of BaseAngleDenseHead + angle_version=angle_version, + use_hbbox_loss=use_hbbox_loss, + angle_coder=angle_coder, + loss_angle=loss_angle, **kwargs) - if loss_angle is not None: - self.loss_angle = MODELS.build(loss_angle) - else: - self.loss_angle = None def _init_layers(self): """Initialize layers of the head.""" @@ -446,238 +444,6 @@ def _get_targets_single(self, return (anchors, labels, label_weights, bbox_targets, assign_metrics, sampling_result) - def predict_by_feat(self, - cls_scores: List[Tensor], - bbox_preds: List[Tensor], - angle_preds: List[Tensor], - score_factors: Optional[List[Tensor]] = None, - batch_img_metas: Optional[List[dict]] = None, - cfg: Optional[ConfigDict] = None, - rescale: bool = False, - with_nms: bool = True) -> InstanceList: - """Transform a batch of output features extracted from the head into - bbox results. - Note: When score_factors is not None, the cls_scores are - usually multiplied by it then obtain the real score used in NMS, - such as CenterNess in FCOS, IoU branch in ATSS. - Args: - cls_scores (list[Tensor]): Classification scores for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * num_classes, H, W). - bbox_preds (list[Tensor]): Box energies / deltas for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * 4, H, W). - angle_preds (list[Tensor]): Box angle for each scale level - with shape (N, num_points * angle_dim, H, W) - score_factors (list[Tensor], optional): Score factor for - all scale level, each is a 4D-tensor, has shape - (batch_size, num_priors * 1, H, W). Defaults to None. - batch_img_metas (list[dict], Optional): Batch image meta info. - Defaults to None. - cfg (ConfigDict, optional): Test / postprocessing - configuration, if None, test_cfg would be used. - Defaults to None. - rescale (bool): If True, return boxes in original image space. - Defaults to False. - with_nms (bool): If True, do nms before return boxes. - Defaults to True. - Returns: - list[:obj:`InstanceData`]: Object detection results of each image - after the post process. Each item usually contains following keys. - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 5), - the last dimension 5 arrange as (x, y, w, h, t). - """ - assert len(cls_scores) == len(bbox_preds) - - if score_factors is None: - # e.g. Retina, FreeAnchor, Foveabox, etc. - with_score_factors = False - else: - # e.g. FCOS, PAA, ATSS, AutoAssign, etc. - with_score_factors = True - assert len(cls_scores) == len(score_factors) - - num_levels = len(cls_scores) - - featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] - mlvl_priors = self.prior_generator.grid_priors( - featmap_sizes, - dtype=cls_scores[0].dtype, - device=cls_scores[0].device) - - result_list = [] - - for img_id in range(len(batch_img_metas)): - img_meta = batch_img_metas[img_id] - cls_score_list = select_single_mlvl( - cls_scores, img_id, detach=True) - bbox_pred_list = select_single_mlvl( - bbox_preds, img_id, detach=True) - angle_pred_list = select_single_mlvl( - angle_preds, img_id, detach=True) - if with_score_factors: - score_factor_list = select_single_mlvl( - score_factors, img_id, detach=True) - else: - score_factor_list = [None for _ in range(num_levels)] - - results = self._predict_by_feat_single( - cls_score_list=cls_score_list, - bbox_pred_list=bbox_pred_list, - angle_pred_list=angle_pred_list, - score_factor_list=score_factor_list, - mlvl_priors=mlvl_priors, - img_meta=img_meta, - cfg=cfg, - rescale=rescale, - with_nms=with_nms) - result_list.append(results) - return result_list - - def _predict_by_feat_single(self, - cls_score_list: List[Tensor], - bbox_pred_list: List[Tensor], - angle_pred_list: List[Tensor], - score_factor_list: List[Tensor], - mlvl_priors: List[Tensor], - img_meta: dict, - cfg: ConfigDict, - rescale: bool = False, - with_nms: bool = True) -> InstanceData: - """Transform a single image's features extracted from the head into - bbox results. - Args: - cls_score_list (list[Tensor]): Box scores from all scale - levels of a single image, each item has shape - (num_priors * num_classes, H, W). - bbox_pred_list (list[Tensor]): Box energies / deltas from - all scale levels of a single image, each item has shape - (num_priors * 4, H, W). - angle_pred_list (list[Tensor]): Box angle for a single scale - level with shape (N, num_points * angle_dim, H, W). - score_factor_list (list[Tensor]): Score factor from all scale - levels of a single image, each item has shape - (num_priors * 1, H, W). - mlvl_priors (list[Tensor]): Each element in the list is - the priors of a single level in feature pyramid. In all - anchor-based methods, it has shape (num_priors, 4). In - all anchor-free methods, it has shape (num_priors, 2) - when `with_stride=True`, otherwise it still has shape - (num_priors, 4). - img_meta (dict): Image meta info. - cfg (mmengine.Config): Test / postprocessing configuration, - if None, test_cfg would be used. - rescale (bool): If True, return boxes in original image space. - Defaults to False. - with_nms (bool): If True, do nms before return boxes. - Defaults to True. - Returns: - :obj:`InstanceData`: Detection results of each image - after the post process. - Each item usually contains following keys. - - scores (Tensor): Classification scores, has a shape - (num_instance, ) - - labels (Tensor): Labels of bboxes, has a shape - (num_instances, ). - - bboxes (Tensor): Has a shape (num_instances, 5), - the last dimension 5 arrange as (x, y, w, h, t). - """ - if score_factor_list[0] is None: - # e.g. Retina, FreeAnchor, etc. - with_score_factors = False - else: - # e.g. FCOS, PAA, ATSS, etc. - with_score_factors = True - - cfg = self.test_cfg if cfg is None else cfg - cfg = copy.deepcopy(cfg) - img_shape = img_meta['img_shape'] - nms_pre = cfg.get('nms_pre', -1) - - mlvl_bbox_preds = [] - mlvl_valid_priors = [] - mlvl_scores = [] - mlvl_labels = [] - if with_score_factors: - mlvl_score_factors = [] - else: - mlvl_score_factors = None - for level_idx, ( - cls_score, bbox_pred, angle_pred, score_factor, priors) in \ - enumerate(zip(cls_score_list, bbox_pred_list, angle_pred_list, - score_factor_list, mlvl_priors)): - - assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - angle_pred = angle_pred.permute(1, 2, 0).reshape( - -1, self.angle_coder.encode_size) - if with_score_factors: - score_factor = score_factor.permute(1, 2, - 0).reshape(-1).sigmoid() - cls_score = cls_score.permute(1, 2, - 0).reshape(-1, self.cls_out_channels) - if self.use_sigmoid_cls: - scores = cls_score.sigmoid() - else: - # remind that we set FG labels to [0, num_class-1] - # since mmdet v2.0 - # BG cat_id: num_class - scores = cls_score.softmax(-1)[:, :-1] - - # After https://github.com/open-mmlab/mmdetection/pull/6268/, - # this operation keeps fewer bboxes under the same `nms_pre`. - # There is no difference in performance for most models. If you - # find a slight drop in performance, you can set a larger - # `nms_pre` than before. - score_thr = cfg.get('score_thr', 0) - - results = filter_scores_and_topk( - scores, score_thr, nms_pre, - dict( - bbox_pred=bbox_pred, angle_pred=angle_pred, priors=priors)) - scores, labels, keep_idxs, filtered_results = results - - bbox_pred = filtered_results['bbox_pred'] - angle_pred = filtered_results['angle_pred'] - priors = filtered_results['priors'] - - decoded_angle = self.angle_coder.decode(angle_pred, keepdim=True) - bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1) - - if with_score_factors: - score_factor = score_factor[keep_idxs] - - mlvl_bbox_preds.append(bbox_pred) - mlvl_valid_priors.append(priors) - mlvl_scores.append(scores) - mlvl_labels.append(labels) - - if with_score_factors: - mlvl_score_factors.append(score_factor) - - bbox_pred = torch.cat(mlvl_bbox_preds) - priors = cat_boxes(mlvl_valid_priors) - bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) - - results = InstanceData() - results.bboxes = RotatedBoxes(bboxes) - results.scores = torch.cat(mlvl_scores) - results.labels = torch.cat(mlvl_labels) - if with_score_factors: - results.score_factors = torch.cat(mlvl_score_factors) - - return self._bbox_post_process( - results=results, - cfg=cfg, - rescale=rescale, - with_nms=with_nms, - img_meta=img_meta) - @MODELS.register_module() class RotatedRTMDetSepBNHead(RotatedRTMDetHead): From 60bac0e2219b3cf73256e5805e67cfb93889c8ec Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Tue, 4 Apr 2023 19:21:08 +0800 Subject: [PATCH 2/5] fix lint --- mmrotate/models/dense_heads/__init__.py | 2 +- .../dense_heads/base_angle_dense_head.py | 21 ++++++++++--------- .../models/dense_heads/rotated_fcos_head.py | 21 +++++-------------- .../models/dense_heads/rotated_rtmdet_head.py | 7 +++---- 4 files changed, 20 insertions(+), 31 deletions(-) diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py index 8351d79c1..033bde640 100644 --- a/mmrotate/models/dense_heads/__init__.py +++ b/mmrotate/models/dense_heads/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .angle_branch_retina_head import AngleBranchRetinaHead +from .base_angle_dense_head import BaseAngleDenseHead from .cfa_head import CFAHead from .h2rbox_head import H2RBoxHead from .oriented_reppoints_head import OrientedRepPointsHead @@ -12,7 +13,6 @@ from .rotated_rtmdet_head import RotatedRTMDetHead, RotatedRTMDetSepBNHead from .s2a_head import S2AHead, S2ARefineHead from .sam_reppoints_head import SAMRepPointsHead -from .base_angle_dense_head import BaseAngleDenseHead __all__ = [ 'RotatedRetinaHead', 'OrientedRPNHead', 'RotatedRepPointsHead', diff --git a/mmrotate/models/dense_heads/base_angle_dense_head.py b/mmrotate/models/dense_heads/base_angle_dense_head.py index f537a52ab..f075fd5ea 100644 --- a/mmrotate/models/dense_heads/base_angle_dense_head.py +++ b/mmrotate/models/dense_heads/base_angle_dense_head.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import copy from typing import List, Optional @@ -5,22 +6,19 @@ from mmdet.models.dense_heads.base_dense_head import BaseDenseHead from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl from mmdet.structures.bbox import cat_boxes -from mmdet.utils import InstanceList +from mmdet.utils import ConfigType, InstanceList, OptConfigType from mmengine import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmdet.utils import (ConfigType, InstanceList, OptConfigType, - OptInstanceList, reduce_mean) -from mmrotate.structures import RotatedBoxes from mmrotate.registry import MODELS, TASK_UTILS +from mmrotate.structures import RotatedBoxes @MODELS.register_module() class BaseAngleDenseHead(BaseDenseHead): - """ - Base class for dense heads with angle. - Commonly, BaseAngleDenseHead will be used with other head. + """Base class for dense heads with angle. Commonly, BaseAngleDenseHead will + be used with other head. Args: angle_version (str, optional): The version of angle. Defaults to @@ -38,7 +36,8 @@ def __init__(self, use_hbbox_loss: bool = False, angle_coder: ConfigType = dict(type='PseudoAngleCoder'), loss_angle: OptConfigType = None, - *args, **kwargs): + *args, + **kwargs): self.angle_version = angle_version self.use_hbbox_loss = use_hbbox_loss self.angle_coder = TASK_UTILS.build(angle_coder) @@ -277,9 +276,11 @@ def _predict_by_feat_single(self, decode_with_angle = cfg.get('decode_with_angle', True) if decode_with_angle: bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1) - bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) else: - bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + bboxes = self.bbox_coder.decode( + priors, bbox_pred, max_shape=img_shape) bboxes = torch.cat([bboxes[..., :4], decoded_angle], dim=-1) results = InstanceData() diff --git a/mmrotate/models/dense_heads/rotated_fcos_head.py b/mmrotate/models/dense_heads/rotated_fcos_head.py index e0f5ba602..4088d6cc4 100644 --- a/mmrotate/models/dense_heads/rotated_fcos_head.py +++ b/mmrotate/models/dense_heads/rotated_fcos_head.py @@ -1,23 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import torch import torch.nn as nn from mmcv.cnn import Scale from mmdet.models.dense_heads import FCOSHead -from mmdet.models.utils import (filter_scores_and_topk, multi_apply, - select_single_mlvl) -from mmdet.structures.bbox import cat_boxes +from mmdet.models.utils import multi_apply from mmdet.utils import (ConfigType, InstanceList, OptConfigType, OptInstanceList, reduce_mean) -from mmengine import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmrotate.registry import MODELS, TASK_UTILS -from mmrotate.structures import RotatedBoxes -from mmrotate.models.dense_heads.base_angle_dense_head import BaseAngleDenseHead +from mmrotate.models.dense_heads.base_angle_dense_head import \ + BaseAngleDenseHead +from mmrotate.registry import TASK_UTILS INF = 1e8 @@ -436,10 +432,3 @@ def _get_targets_single( angle_targets = gt_angle[range(num_points), min_area_inds] return labels, bbox_targets, angle_targets - - -if __name__=='__main__': - from mmrotate.utils import register_all_modules - register_all_modules() - fcos_head = RotatedFCOSHead(num_classes=2, in_channels=256, feat_channels=256, strides=[8, 16, 32, 64, 128], angle_version='le90') - print(RotatedFCOSHead.mro()) \ No newline at end of file diff --git a/mmrotate/models/dense_heads/rotated_rtmdet_head.py b/mmrotate/models/dense_heads/rotated_rtmdet_head.py index 038fc6058..8932abd99 100644 --- a/mmrotate/models/dense_heads/rotated_rtmdet_head.py +++ b/mmrotate/models/dense_heads/rotated_rtmdet_head.py @@ -6,9 +6,7 @@ from mmdet.models import inverse_sigmoid from mmdet.models.dense_heads import RTMDetHead from mmdet.models.task_modules import anchor_inside_flags -from mmdet.models.utils import (multi_apply, - sigmoid_geometric_mean, - unmap) +from mmdet.models.utils import multi_apply, sigmoid_geometric_mean, unmap from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, distance2bbox from mmdet.utils import (ConfigType, InstanceList, OptConfigType, OptInstanceList, reduce_mean) @@ -16,9 +14,10 @@ from mmengine.structures import InstanceData from torch import Tensor, nn +from mmrotate.models.dense_heads.base_angle_dense_head import \ + BaseAngleDenseHead from mmrotate.registry import MODELS from mmrotate.structures import distance2obb -from mmrotate.models.dense_heads.base_angle_dense_head import BaseAngleDenseHead @MODELS.register_module() From 18ae83e831c3ee3140dfac63041f36ede1ecd67b Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Tue, 4 Apr 2023 19:42:03 +0800 Subject: [PATCH 3/5] fix --- mmrotate/models/dense_heads/rotated_fcos_head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmrotate/models/dense_heads/rotated_fcos_head.py b/mmrotate/models/dense_heads/rotated_fcos_head.py index 4088d6cc4..2414a106c 100644 --- a/mmrotate/models/dense_heads/rotated_fcos_head.py +++ b/mmrotate/models/dense_heads/rotated_fcos_head.py @@ -13,12 +13,12 @@ from mmrotate.models.dense_heads.base_angle_dense_head import \ BaseAngleDenseHead -from mmrotate.registry import TASK_UTILS +from mmrotate.registry import MODELS, TASK_UTILS INF = 1e8 -# @MODELS.register_module() +@MODELS.register_module() class RotatedFCOSHead(BaseAngleDenseHead, FCOSHead): """Anchor-free head used in `FCOS `_. From 9eac6cda9ad06af488b7d17753ea10cfd7d8667f Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Wed, 5 Apr 2023 11:19:12 +0800 Subject: [PATCH 4/5] fix error --- mmrotate/models/dense_heads/base_angle_dense_head.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mmrotate/models/dense_heads/base_angle_dense_head.py b/mmrotate/models/dense_heads/base_angle_dense_head.py index f075fd5ea..81756c544 100644 --- a/mmrotate/models/dense_heads/base_angle_dense_head.py +++ b/mmrotate/models/dense_heads/base_angle_dense_head.py @@ -41,17 +41,18 @@ def __init__(self, self.angle_version = angle_version self.use_hbbox_loss = use_hbbox_loss self.angle_coder = TASK_UTILS.build(angle_coder) + # Commonly, BaseAngleDenseHead will be used with other head. + # So we call super here to init the other head. + # For example, RotatedFCOSHead will used with FCOSHead, + # so super here will call the init function of FCOSHead. + super().__init__(*args, **kwargs) + if loss_angle is not None: self.loss_angle = MODELS.build(loss_angle) else: self.loss_angle = None if self.use_hbbox_loss: assert self.loss_angle is not None - # Commonly, BaseAngleDenseHead will be used with other head. - # So we call super here to init the other head. - # For example, RotatedFCOSHead will used with FCOSHead, - # so super here will call the init function of FCOSHead. - super().__init__(*args, **kwargs) def predict_by_feat(self, cls_scores: List[Tensor], From 22e464f5badab8e6da3283f02dbd2df3891af251 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Wed, 5 Apr 2023 12:38:48 +0800 Subject: [PATCH 5/5] fix test error --- mmrotate/models/dense_heads/rotated_rtmdet_head.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmrotate/models/dense_heads/rotated_rtmdet_head.py b/mmrotate/models/dense_heads/rotated_rtmdet_head.py index 8932abd99..1a2c1b82a 100644 --- a/mmrotate/models/dense_heads/rotated_rtmdet_head.py +++ b/mmrotate/models/dense_heads/rotated_rtmdet_head.py @@ -49,9 +49,9 @@ def __init__(self, self.is_scale_angle = scale_angle super().__init__( # Arguments of RTMDetHead - num_classes, - in_channels, - # useless, but error + num_classes=num_classes, + in_channels=in_channels, + # add scope to prevent error loss_centerness=dict( type='mmdet.CrossEntropyLoss', use_sigmoid=True, @@ -481,8 +481,8 @@ def __init__(self, assert scale_angle is False, \ 'scale_angle does not support in RotatedRTMDetSepBNHead' super().__init__( - num_classes, - in_channels, + num_classes=num_classes, + in_channels=in_channels, norm_cfg=norm_cfg, act_cfg=act_cfg, pred_kernel_size=pred_kernel_size,