diff --git a/mmrotate/datasets/__init__.py b/mmrotate/datasets/__init__.py index c0a39236f..5f7129275 100644 --- a/mmrotate/datasets/__init__.py +++ b/mmrotate/datasets/__init__.py @@ -3,9 +3,10 @@ from .dota import DOTAv2Dataset # noqa: F401, F403 from .dota import DOTADataset, DOTAv15Dataset from .hrsc import HRSCDataset # noqa: F401, F403 +from .ohd_sjtu import OHD_SJTUDataset_L, OHD_SJTUDataset_S from .transforms import * # noqa: F401, F403 __all__ = [ 'DOTADataset', 'DOTAv15Dataset', 'DOTAv2Dataset', 'HRSCDataset', - 'DIORDataset' + 'DIORDataset', 'OHD_SJTUDataset_S', 'OHD_SJTUDataset_L' ] diff --git a/mmrotate/datasets/ohd_sjtu.py b/mmrotate/datasets/ohd_sjtu.py new file mode 100644 index 000000000..43cf8d3b3 --- /dev/null +++ b/mmrotate/datasets/ohd_sjtu.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import glob +import os.path as osp +from typing import List, Tuple + +from mmengine.dataset import BaseDataset + +from mmrotate.registry import DATASETS + + +@DATASETS.register_module() +class OHD_SJTUDataset_S(BaseDataset): + """OHD-SJTU-S dataset for detection. + + Note: 'ann_file' in OHD_SJTUDataset is different from the BaseDataset. + In BaseDataset, it is the path of an annotation file. In OHD_SJTUDataset, + it is the path of a folder containing txt files. + + Args: + img_shape (tuple[int]): + diff_thr (int): + """ + + METAINFO = { + 'classes': ('ship', 'plane'), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(165, 42, 42), (189, 183, 107)] + } + + def __init__(self, + img_shape: Tuple[int, int] = (1024, 1024), + diff_thr: int = 100, + **kwargs) -> None: + self.img_shape = img_shape + self.diff_thr = diff_thr + super().__init__(**kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as 'self.ann_file'. + + Returns: + List[dict]: A list of annotation. 
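+                Each dict provides ``img_path``, ``file_name``, ``img_id``,
+                ``height``, ``width`` and ``instances``.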
+ """ + assert self._metainfo.get('classes', None) is not None, \ + "classes in 'OHD-SJTUDataset' can not be None" + + cls_map = {c: i for i, c in enumerate(self.metainfo['classes'])} + data_list = [] + if self.ann_file == '': + img_files = glob.glob( + osp.join(self.data_prefix['img_path'], '*.png')) + for img_path in img_files: + data_info = {'img_path': img_path} + img_name = osp.split(img_path)[1] + data_info['file_name'] = img_name + img_id = img_name[:-4] + data_info['img_id'] = img_id + data_info['height'] = self.img_shape[0] + data_info['width'] = self.img_shape[1] + + instance = dict( + bbox=[], bbox_head=[], bbox_label=[], ignore_flag=0) + data_info['instances'] = [instance] + data_list.append(data_info) + return data_list + else: + txt_files = glob.glob(osp.join(self.ann_file, '*.txt')) + if len(txt_files) == 0: + raise ValueError('There is no txt file in ' + f'{self.ann_file}') + for txt_file in txt_files: + img_id = osp.split(txt_file)[1][:-4] + data_info = {'img_id': img_id} + img_name = img_id + '.png' + data_info['file_name'] = img_name + data_info['img_path'] = osp.join(self.data_prefix['img_path'], + img_name) + data_info['height'] = self.img_shape[0] + data_info['width'] = self.img_shape[1] + + instances = [] + with open(txt_file) as f: + contents = f.readlines() + for content in contents: + bbox_info = content.split(' ') + instance = {'bbox': [float(i) for i in bbox_info[:10]]} + cls_name = bbox_info[-2] + instance['bbox_label'] = cls_map[cls_name] + difficulty = int(bbox_info[-1]) + if difficulty > self.diff_thr: + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instances.append(instance) + data_info['instances'] = instances + data_list.append(data_info) + print(len(data_list)) + return data_list + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. + + Returns: + List[dict]: Filtered results. + """ + if self.test_mode: + return self.data_list + + filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \ + if self.filter_cfg is not None else False + + valid_data_infos = [] + for i, data_info in enumerate(self.data_list): + if filter_empty_gt and len(data_info['instances']) == 0: + continue + valid_data_infos.append(data_info) + + return valid_data_infos + + def get_cat_ids(self, idx: int) -> List[int]: + """Get OHD-SJTU category ids by index. + + Args: + idx(int): Index of data. + Returns: + List[int]: All categories in the image of specified index. + """ + + instances = self.get_data_info(idx)['instances'] + return [instance['bbox_label'] for instance in instances] + + +@DATASETS.register_module() +class OHD_SJTUDataset_L(OHD_SJTUDataset_S): + """OHD-SJTU-L dataset for detection. + + Note: 'ann_file' in OHD_SJTUDataset is different from the BaseDataset. + In BaseDataset, it is the path of an annotation file. In OHD_SJTUDataset, + it is the path of a folder containing txt files. + """ + + METAINFO = { + 'classes': ('ship', 'plane', 'small-vehicle', 'large-vehicle', + 'harbor', 'helicopter'), + # palette is a list of color tuples, which is used for visualization. + 'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), + (138, 43, 226), (255, 128, 0)] + } diff --git a/mmrotate/evaluation/functional/mean_ap.py b/mmrotate/evaluation/functional/mean_ap.py index e4baf1e3b..ec6260394 100644 --- a/mmrotate/evaluation/functional/mean_ap.py +++ b/mmrotate/evaluation/functional/mean_ap.py @@ -18,14 +18,17 @@ def tpfp_default(det_bboxes, """Check if detected bboxes are true positive or false positive. 
Args: - det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6). - gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5). + det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6) + or (m, 7). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5) + or (n, 6). gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, - of shape (k, 5). Defaults to None + of shape (k, 5) or (k, 6). Defaults to None iou_thr (float): IoU threshold to be considered as matched. Defaults to 0.5. - box_type (str): Box type. If the QuadriBoxes is used, you need to - specify 'qbox'. Defaults to 'rbox'. + box_type (str): Box type. If the QuadriBoxes or RotatedHeadBoxes is + used, you need to specify 'qbox' or 'rheadbox'. Defaults to + 'rbox'. area_ranges (list[tuple], optional): Range of bbox areas to be evaluated, in the format [(min1, max1), (min2, max2), ...]. Defaults to None. @@ -51,6 +54,7 @@ def tpfp_default(det_bboxes, # a certain scale tp = np.zeros((num_scales, num_dets), dtype=np.float32) fp = np.zeros((num_scales, num_dets), dtype=np.float32) + tp_head = np.zeros((num_scales, num_dets), dtype=np.float32) # if there is no gt bboxes in this image, then all det bboxes # within area range are false positives @@ -59,12 +63,19 @@ def tpfp_default(det_bboxes, fp[...] = 1 else: raise NotImplementedError - return tp, fp + if box_type == 'rheadbox': + return tp, fp, tp_head + else: + return tp, fp if box_type == 'rbox': ious = box_iou_rotated( torch.from_numpy(det_bboxes).float(), torch.from_numpy(gt_bboxes).float()).numpy() + elif box_type == 'rheadbox': + ious = box_iou_rotated( + torch.from_numpy(det_bboxes[..., :5]).float(), + torch.from_numpy(gt_bboxes[..., :5]).float()).numpy() elif box_type == 'qbox': ious = box_iou_quadri( torch.from_numpy(det_bboxes).float(), @@ -91,7 +102,12 @@ def tpfp_default(det_bboxes, or gt_area_ignore[matched_gt]): if not gt_covered[matched_gt]: gt_covered[matched_gt] = True - tp[k, i] = 1 + if box_type == 'rheadbox': + if gt_bboxes[matched_gt, -1] == det_bboxes[i, -2]: + tp[k, i] = 1 + tp_head[k, i] = 1 + else: + tp[k, i] = 1 else: fp[k, i] = 1 # otherwise ignore this detected bbox, tp = 0, fp = 0 @@ -101,6 +117,9 @@ def tpfp_default(det_bboxes, if box_type == 'rbox': bbox = det_bboxes[i, :5] area = bbox[2] * bbox[3] + elif box_type == 'rheadbox': + bbox = det_bboxes[i, :5] + area = bbox[2] * bbox[3] elif box_type == 'qbox': bbox = det_bboxes[i, :8] pts = bbox.reshape(*bbox.shape[:-1], 4, 2) @@ -114,7 +133,10 @@ def tpfp_default(det_bboxes, raise NotImplementedError if area >= min_area and area < max_area: fp[k, i] = 1 - return tp, fp + if box_type == 'rheadbox': + return tp, fp, tp_head + else: + return tp, fp def get_cls_results(det_results, annotations, class_id, box_type): @@ -147,6 +169,9 @@ def get_cls_results(det_results, annotations, class_id, box_type): elif box_type == 'qbox': cls_gts.append(torch.zeros((0, 8), dtype=torch.float64)) cls_gts_ignore.append(torch.zeros((0, 8), dtype=torch.float64)) + elif box_type == 'rheadbox': + cls_gts.append(torch.zeros((0, 6), dtype=torch.float64)) + cls_gts_ignore.append(torch.zeros((0, 6))) else: raise NotImplementedError @@ -171,9 +196,10 @@ def eval_rbbox_map(det_results, annotations (list[dict]): Ground truth annotations where each item of the list indicates an image. 
Keys of annotations are: - - `bboxes`: numpy array of shape (n, 5) + - `bboxes`: numpy array of shape (n, 5) or (n, 6) - `labels`: numpy array of shape (n, ) - `bboxes_ignore` (optional): numpy array of shape (k, 5) + or (k, 6) - `labels_ignore` (optional): numpy array of shape (k, ) scale_ranges (list[tuple], optional): Range of scales to be evaluated, in the format [(min1, max1), (min2, max2), ...]. A range of @@ -212,13 +238,22 @@ def eval_rbbox_map(det_results, det_results, annotations, i, box_type) # compute tp and fp for each image with multiple processes - tpfp = pool.starmap( - tpfp_default, - zip(cls_dets, cls_gts, cls_gts_ignore, - [iou_thr for _ in range(num_imgs)], - [box_type for _ in range(num_imgs)], - [area_ranges for _ in range(num_imgs)])) - tp, fp = tuple(zip(*tpfp)) + if box_type == 'rheadbox': + tpfphead = pool.starmap( + tpfp_default, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [box_type for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)])) + tp, fp, tp_head = tuple(zip(*tpfphead)) + else: + tpfp = pool.starmap( + tpfp_default, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [box_type for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)])) + tp, fp = tuple(zip(*tpfp)) # calculate gt number of each scale # ignored gts or gts beyond the specific scale are not counted num_gts = np.zeros(num_scales, dtype=int) @@ -226,7 +261,7 @@ def eval_rbbox_map(det_results, if area_ranges is None: num_gts[0] += bbox.shape[0] else: - if box_type == 'rbox': + if box_type == 'rbox' or box_type == 'rheadbox': gt_areas = bbox[:, 2] * bbox[:, 3] elif box_type == 'qbox': pts = bbox.reshape(*bbox.shape[:-1], 4, 2) @@ -267,6 +302,13 @@ def eval_rbbox_map(det_results, 'precision': precisions, 'ap': ap }) + if box_type == 'rheadbox': + tp_head = np.hstack(tp_head) + head_acc = [] + for tp_s_head in tp_head: + head_acc.append(sum(tp_s_head) / len(tp_s_head)) + head_acc = np.array(head_acc) + eval_results[-1]['head_acc'] = head_acc pool.close() if scale_ranges is not None: # shape (num_classes, num_scales) @@ -286,16 +328,43 @@ def eval_rbbox_map(det_results, aps.append(cls_result['ap']) mean_ap = np.array(aps).mean().item() if aps else 0.0 - print_map_summary( - mean_ap, eval_results, dataset, area_ranges, logger=logger) - - return mean_ap, eval_results + if box_type == 'rheadbox': + if scale_ranges is not None: + all_head_acc = np.vstack( + [cls_result['head_acc'] for cls_result in eval_results]) + mean_head_acc = [] + for i in range(num_scales): + if np.any(all_num_gts[:, i] > 0): + mean_head_acc.append(all_head_acc[all_num_gts[:, i] > 0, + i].mean()) + else: + mean_head_acc.append(0.0) + else: + head_accs = [] + for cls_result in eval_results: + if cls_result['num_gts'] > 0: + head_accs.append(cls_result['head_acc']) + mean_head_acc = np.array( + head_accs).mean().item() if head_accs else 0.0 + print_map_summary( + mean_ap, + eval_results, + dataset, + area_ranges, + mean_head_acc=mean_head_acc, + logger=logger) + return mean_ap, mean_head_acc, eval_results + else: + print_map_summary( + mean_ap, eval_results, dataset, area_ranges, logger=logger) + return mean_ap, eval_results def print_map_summary(mean_ap, results, dataset=None, scale_ranges=None, + mean_head_acc=None, logger=None): """Print mAP and results of each class. @@ -307,6 +376,8 @@ def print_map_summary(mean_ap, results (list[dict]): Calculated from `eval_map()`. dataset (list[str] | str, optional): Dataset name or dataset classes. 
scale_ranges (list[tuple], optional): Range of scales to be evaluated. + mean_head_acc (float): Calculated from 'eval_map()', mean head + detection accuracy, if None, no head detection. logger (logging.Logger | str, optional): The way to print the mAP summary. See `mmcv.utils.print_log()` for details. Defaults to None. @@ -328,11 +399,13 @@ def print_map_summary(mean_ap, recalls = np.zeros((num_scales, num_classes), dtype=np.float32) aps = np.zeros((num_scales, num_classes), dtype=np.float32) num_gts = np.zeros((num_scales, num_classes), dtype=int) + head_accs = np.zeros((num_scales, num_classes), dtype=np.float32) for i, cls_result in enumerate(results): if cls_result['recall'].size > 0: recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] aps[:, i] = cls_result['ap'] num_gts[:, i] = cls_result['num_gts'] + head_accs[:, i] = cls_result['head_acc'] if mean_head_acc else None if dataset is None: label_names = [str(i) for i in range(num_classes)] @@ -341,19 +414,35 @@ def print_map_summary(mean_ap, if not isinstance(mean_ap, list): mean_ap = [mean_ap] + if mean_head_acc and not isinstance(mean_head_acc, list): + mean_head_acc = [mean_head_acc] - header = ['class', 'gts', 'dets', 'recall', 'ap'] + header = ['class', 'gts', 'dets', 'recall', 'ap', 'head_acc'] if \ + mean_head_acc else ['class', 'gts', 'dets', 'recall', 'ap'] for i in range(num_scales): if scale_ranges is not None: print_log(f'Scale range {scale_ranges[i]}', logger=logger) table_data = [header] for j in range(num_classes): - row_data = [ - label_names[j], num_gts[i, j], results[j]['num_dets'], - f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' - ] + if mean_head_acc: + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}', + f'{head_accs[i, j]:.3f}' + ] + else: + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' + ] table_data.append(row_data) - table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) + if mean_head_acc: + table_data.append([ + 'mAP', '', '', '', f'{mean_ap[i]:.3f}', + f'{mean_head_acc[i]:.3f}' + ]) + else: + table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) table = AsciiTable(table_data) table.inner_footing_row_border = True print_log('\n' + table.table, logger=logger) diff --git a/mmrotate/evaluation/metrics/__init__.py b/mmrotate/evaluation/metrics/__init__.py index 0de029ff3..399155434 100644 --- a/mmrotate/evaluation/metrics/__init__.py +++ b/mmrotate/evaluation/metrics/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dota_metric import DOTAMetric +from .ohd_sjtu_metric import OHD_SJTUMetric from .rotated_coco_metric import RotatedCocoMetric -__all__ = ['DOTAMetric', 'RotatedCocoMetric'] +__all__ = ['DOTAMetric', 'RotatedCocoMetric', 'OHD_SJTUMetric'] diff --git a/mmrotate/evaluation/metrics/ohd_sjtu_metric.py b/mmrotate/evaluation/metrics/ohd_sjtu_metric.py new file mode 100644 index 000000000..218fc233b --- /dev/null +++ b/mmrotate/evaluation/metrics/ohd_sjtu_metric.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
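+# The head of each object is evaluated as a quadrant id in {0, 1, 2, 3},
+# derived from the offset of the annotated head point w.r.t. the box center;
+# see ``get_head_quadrant`` below.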
+import copy +import os.path as osp +import tempfile +from collections import OrderedDict +from typing import List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.logging import MMLogger +from torch import Tensor + +from mmrotate.evaluation import eval_rbbox_map +from mmrotate.registry import METRICS +from .dota_metric import DOTAMetric + + +@METRICS.register_module() +class OHD_SJTUMetric(DOTAMetric): + """OHD-SJTU evaluation metric, adding box head evaluation based on + DOTAMetric. + + Note: In addition to format the output results to JSON like CocoMetric, + it can also generate the full image's results by merging patches' results. + The premise is that you must use the tool provided by us to crop the + OHD-SJTU large images, which can be found at: ``tools/data/dota/split``. + + Args: + """ + + default_prefix: Optional[str] = 'ohd-sjtu' + + def __init__(self, + *args, + metric: Union[str, List[str]] = 'mAP', + predict_box_type: str = 'rheadbox'): + super().__init__( + *args, metric=metric, predict_box_type=predict_box_type) + + def get_head_quadrant(self, rheadboxes: Tensor): + """convert headxys in rheadboxes into head quadrants. + + Args: + rheadboxes (Tensor): rotated head boxes, each of item is + [x_c, y_c, w, h, t, head_x, head_y]. + + Return: + Tensor: each of item is [x_c, y_c, w, h, t, head_quadrant] + """ + center_xys, head_xys = rheadboxes[..., :2], rheadboxes[..., -2:] + original_shape = head_xys.shape[:-1] + assert center_xys.size(-1) == 2 and head_xys.size(-1) == 2, \ + ('The last dimension of two input params must be 2, representing ' + f'xy coordinates, but got center_xys {center_xys.shape}, ' + f'head_xys {head_xys.shape}.') + head_quadrants = [] + for center_xy, head_xy in zip(center_xys, head_xys): + delta_x = head_xy[0] - center_xy[0] + delta_y = head_xy[1] - center_xy[1] + if (delta_x >= 0) and (delta_y >= 0): + head_quadrants.append(0) + elif (delta_x >= 0) and (delta_y <= 0): + head_quadrants.append(1) + elif (delta_x <= 0) and (delta_y <= 0): + head_quadrants.append(2) + else: + head_quadrants.append(3) + head_quadrants = head_xys.new_tensor(head_quadrants) + head_quadrants = head_quadrants.view(*original_shape, 1) + return torch.cat([rheadboxes[..., :5], head_quadrants], dim=-1) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. 
+ """ + for data_sample in data_samples: + gt = copy.deepcopy(data_sample) + gt_instances = gt['gt_instances'] + gt_ignore_instances = gt['ignored_instances'] + if gt_instances == {}: + ann = dict() + else: + rheadbboxes = gt_instances['bboxes'] + rhead_ignore_bboxes = gt_ignore_instances['bboxes'] + # convert head xy in bboxes into head quadrant + rbboxes_hqu = self.get_head_quadrant(rheadbboxes) + if rhead_ignore_bboxes.shape[0] == 0: + rbboxes_hqu_ignore = rhead_ignore_bboxes[..., :6] + else: + rbboxes_hqu_ignore = self.get_head_quadrant( + rhead_ignore_bboxes) + ann = dict( + labels=gt_instances['labels'].cpu().numpy(), + bboxes=rbboxes_hqu.cpu().numpy(), + bboxes_ignore=rbboxes_hqu_ignore.cpu().numpy(), + labels_ignore=gt_ignore_instances['labels'].cpu().numpy()) + result = dict() + pred = data_sample['pred_instances'] + result['img_id'] = data_sample['img_id'] + result['bboxes'] = pred['bboxes'].cpu().numpy() + result['heads'] = pred['heads'].cpu().numpy() + result['scores'] = pred['scores'].cpu().numpy() + result['labels'] = pred['labels'].cpu().numpy() + result['pred_bbox_head_scores'] = [] + # get prediction in each class + for label in range(len(self.dataset_meta['classes'])): + index = np.where(result['labels'] == label)[0] + pred_bbox_head_scores = np.hstack([ + result['bboxes'][index], + result['heads'][index].reshape(-1, 1), + result['scores'][index].reshape(-1, 1) + ]) + result['pred_bbox_head_scores'].append(pred_bbox_head_scores) + + self.results.append((ann, result)) + + def compute_metrics(self, results: list) -> dict: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + Returns: + dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. 
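+                For this metric the keys include ``mAP``, ``mean_head_acc``
+                and the per-IoU-threshold ``APxx`` / ``head_accxx`` entries.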
+ """ + logger: MMLogger = MMLogger.get_current_instance() + # anns, results + gts, preds = zip(*results) + + tmp_dir = None + if self.outfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + outfile_prefix = osp.join(tmp_dir.name, 'results') + else: + outfile_prefix = self.outfile_prefix + + eval_results = OrderedDict() + if self.merge_patches: + # convert predictions to txt format and dump to zip file + zip_path = self.merge_results(preds, outfile_prefix) + logger.info(f'The submission file save at {zip_path}') + return eval_results + else: + # convert predictions to coco format and dump to json file + _ = self.results2json(preds, outfile_prefix) + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return eval_results + + if self.metric == 'mAP': + assert isinstance(self.iou_thrs, list) + dataset_name = self.dataset_meta['classes'] + dets = [pred['pred_bbox_head_scores'] for pred in preds] + + mean_aps = [] + mean_heads = [] + for iou_thr in self.iou_thrs: + logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') + mean_ap, mean_head_acc, _ = eval_rbbox_map( + dets, + gts, + scale_ranges=self.scale_ranges, + iou_thr=iou_thr, + use_07_metric=self.use_07_metric, + box_type=self.predict_box_type, + dataset=dataset_name, + logger=logger) + mean_aps.append(mean_ap) + mean_heads.append(mean_head_acc) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results[f'head_acc{int(iou_thr * 100):02d}'] = round( + mean_head_acc, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + eval_results['mean_head_acc'] = sum(mean_heads) / len(mean_heads) + eval_results.move_to_end('mean_head_acc', last=False) + eval_results.move_to_end('mAP', last=False) + else: + raise NotImplementedError + return eval_results diff --git a/mmrotate/models/dense_heads/rotated_retina_head.py b/mmrotate/models/dense_heads/rotated_retina_head.py index 931a3d7c3..8315cc6a3 100644 --- a/mmrotate/models/dense_heads/rotated_retina_head.py +++ b/mmrotate/models/dense_heads/rotated_retina_head.py @@ -1,9 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Optional, Union + +import torch +import torch.nn as nn from mmdet.models.dense_heads import RetinaHead -from mmdet.structures.bbox import get_box_tensor +from mmdet.models.task_modules.prior_generators import anchor_inside_flags +from mmdet.models.utils import (filter_scores_and_topk, images_to_levels, + multi_apply, select_single_mlvl, unmap) +from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor +from mmdet.utils import ConfigType, InstanceList, OptInstanceList +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData from torch import Tensor from mmrotate.registry import MODELS +from mmrotate.structures.bbox import RotatedHeadBoxes @MODELS.register_module() @@ -13,19 +25,253 @@ class RotatedRetinaHead(RetinaHead): Args: loss_bbox_type (str): Set the input type of ``loss_bbox``. Defaults to 'normal'. + head_detec (bool): Whether detect the head of obj. + Defaults to 'True'. 
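+        loss_head (:obj:`ConfigDict` or dict): Config of the head-quadrant
+            classification loss. Defaults to ``mmdet.FocalLoss`` with
+            ``use_sigmoid=True`` and ``loss_weight=1.0``.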
""" def __init__(self, *args, loss_bbox_type: str = 'normal', - **kwargs) -> None: + head_detec: bool = True, + loss_head: ConfigType = dict( + type='mmdet.FocalLoss', use_sigmoid=True, + loss_weight=1.0), + **kwargs): + self.head_detec = head_detec super().__init__(*args, **kwargs) self.loss_bbox_type = loss_bbox_type + self.loss_head = MODELS.build(loss_head) + + def _init_layers(self): + """Initialize layers of the head.""" + super()._init_layers() + if self.head_detec: + self.conv_head = nn.Conv2d( + self.in_channels, self.num_base_priors * 4, 3, padding=1) + + def forward_single(self, x): + """Forward feature of a single scale level. + + Args: + x (Tensor): Features of a single scale level. + + Returns: + tuple: + cls_score (Tensor): Cls scores for a single scale level + the channels number is num_anchors * num_classes. + bbox_pred (Tensor): Box energies / deltas for a single scale + level, the channels number is num_anchors * 5. + head_cls (Tensor): Box head for a single scale level, the + channels number is num_anchors * 4.(Optional) + """ + cls_feat = x + reg_feat = x + for cls_conv in self.cls_convs: + cls_feat = cls_conv(cls_feat) + for reg_conv in self.reg_convs: + reg_feat = reg_conv(reg_feat) + cls_score = self.retina_cls(cls_feat) + bbox_pred = self.retina_reg(reg_feat) + if self.head_detec: + head_cls = self.conv_head(reg_feat) + else: + head_cls = None + return cls_score, bbox_pred, head_cls + + def _get_targets_single(self, + flat_anchors: Union[Tensor, BaseBoxes], + valid_flags: Tensor, + gt_instances: InstanceData, + img_meta: dict, + gt_instances_ignore: Optional[InstanceData] = None, + unmap_outputs: bool = True) -> tuple: + """Compute regression and classification targets for anchors in a + single image. + + Args: + flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors + of the image, which are concatenated into a single tensor + or box type of shape (num_anchors, 4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors, ). + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for current image. + gt_instances_ignore (:obj:`InstanceData`, optional): Instances + to be ignored during training. It includes ``bboxes`` + attribute data that is ignored during training and testing. + Defaults to None. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. Defaults to True. + + Returns: + tuple: + + - labels (Tensor): Labels of each level. + - label_weights (Tensor): Label weights of each level. + - bbox_targets (Tensor): BBox targets of each level. + - bbox_weights (Tensor): BBox weights of each level. + - pos_inds (Tensor): positive samples indexes. + - neg_inds (Tensor): negative samples indexes. + - sampling_result (:obj:`SamplingResult`): Sampling results. + - head_targets (Tensor): BBox head targets of each level. + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg['allowed_border']) + if not inside_flags.any(): + raise ValueError( + 'There is no valid anchor inside the image boundary. 
Please ' + 'check the image size and anchor sizes, or set ' + '``allowed_border`` to -1 to skip the condition.') + # filter anchors within valid border + anchors = flat_anchors[inside_flags] + pred_instances = InstanceData(priors=anchors) + # assign gt with filtered anchors + assign_result = self.assigner.assign(pred_instances, gt_instances, + gt_instances_ignore) + sampling_result = self.sampler.sample(assign_result, pred_instances, + gt_instances) + num_valid_anchors = anchors.shape[0] + # if gtboxes is RotatedHeadBoxes, the boxes dim must -2 + # because the last two is box headx and heady + if isinstance(gt_instances.bboxes, RotatedHeadBoxes): + target_dim = (gt_instances.bboxes.size(-1) - 2) if \ + self.reg_decoded_bbox else self.bbox_coder.encode_size + else: + target_dim = gt_instances.bboxes.size(-1) if \ + self.reg_decoded_bbox else self.bbox_coder.encode_size + bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim) + bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim) + head_targets = anchors.new_full((num_valid_anchors, ), + -1, + dtype=torch.int) + gt_head_quadrants = gt_instances.bboxes.heads + + # TODO: Considering saving memory, is it necessary to be long? + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + head_targets[pos_inds] = gt_head_quadrants[ + assign_result.gt_inds[pos_inds] - 1] + if not self.reg_decoded_bbox: + pos_bbox_targets = self.bbox_coder.encode( + sampling_result.pos_priors, sampling_result.pos_gt_bboxes) + else: + pos_bbox_targets = sampling_result.pos_gt_bboxes + pos_bbox_targets = get_box_tensor(pos_bbox_targets) + if pos_bbox_targets.size(-1) != 5: + pos_bbox_targets = pos_bbox_targets[..., :5] + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + labels[pos_inds] = sampling_result.pos_gt_labels + if self.train_cfg['pos_weight'] <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg['pos_weight'] + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + labels = unmap( + labels, num_total_anchors, inside_flags, + fill=self.num_classes) # fill bg label + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) + head_targets = unmap( + head_targets, num_total_anchors, inside_flags, fill=-1) + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds, sampling_result, head_targets) + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + head_clses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + has shape (N, num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 5, H, W). 
+ head_clses (list[Tensor]): Box head scores for each scale level + has shape (N, num_anchors * 4, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + avg_factor, head_targets_list) = cls_reg_targets + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + # concat all level anchors and flags to a single tensor + concat_anchor_list = [] + for i in range(len(anchor_list)): + concat_anchor_list.append(cat_boxes(anchor_list[i])) + all_anchor_list = images_to_levels(concat_anchor_list, + num_level_anchors) + losses_cls, losses_box, losses_head = multi_apply( + self.loss_by_feat_single, + cls_scores, + bbox_preds, + head_clses, + all_anchor_list, + labels_list, + label_weights_list, + bbox_targets_list, + bbox_weights_list, + head_targets_list, + avg_factor=avg_factor) + if losses_head is None: + return dict(loss_cls=losses_cls, loss_bbox=losses_box) + else: + return dict( + loss_cls=losses_cls, + loss_bbox=losses_box, + loss_head=losses_head) def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, - anchors: Tensor, labels: Tensor, + head_cls: Tensor, anchors: Tensor, labels: Tensor, label_weights: Tensor, bbox_targets: Tensor, - bbox_weights: Tensor, avg_factor: int) -> tuple: + bbox_weights: Tensor, head_targets: Tensor, + avg_factor: int) -> tuple: """Calculate the loss of a single scale level based on the features extracted by the detection head. @@ -33,7 +279,9 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, cls_score (Tensor): Box scores for each scale level Has shape (N, num_anchors * num_classes, H, W). bbox_pred (Tensor): Box energies / deltas for each scale - level with shape (N, num_anchors * 4, H, W). + level with shape (N, num_anchors * 5, H, W). + head_cls (Tensor): Box head scores for each scale level with + shape (N, num_anchors * 4, H, W) anchors (Tensor): Box reference for each scale level with shape (N, num_total_anchors, 4). labels (Tensor): Labels of each anchors with shape @@ -44,6 +292,8 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, weight shape (N, num_total_anchors, 4). bbox_weights (Tensor): BBox regression loss weights of each anchor with shape (N, num_total_anchors, 4). + head_targets (Tensor): BBox head scores of each anchor with shape + (N, num_total_anchors) avg_factor (int): Average factor that is used to average the loss. 
Returns: @@ -56,6 +306,15 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, 1).reshape(-1, self.cls_out_channels) loss_cls = self.loss_cls( cls_score, labels, label_weights, avg_factor=avg_factor) + + # head loss + if self.head_detec: + head_targets = head_targets.reshape(-1).long() + head_cls = head_cls.permute(0, 2, 3, 1).reshape(-1, 4) + loss_head = self.loss_head(head_cls, head_targets) + else: + loss_head = None + # regression loss target_dim = bbox_targets.size(-1) bbox_targets = bbox_targets.reshape(-1, target_dim) @@ -93,4 +352,240 @@ def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor, else: raise NotImplementedError - return loss_cls, loss_bbox + return loss_cls, loss_bbox, loss_head + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + head_clses: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 5D-tensor, has shape + (batch_size, num_priors * 5, H, W). + head_clses (list[Tensor]): Box heads for all scale levels, + each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 5), + the last dimension 5 arrange as (x1, y1, x2, y2, theta). + - bboxes_head (Tensor): Has a shape (num_instances, 4), + the last dimension 4 represents scores of 4 head edges. + """ + assert len(cls_scores) == len(bbox_preds) + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. 
+ with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device) + + result_list = [] + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl( + cls_scores, img_id, detach=True) + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + if head_clses is None: + head_cls_list = None + else: + head_cls_list = select_single_mlvl( + head_clses, img_id, detach=True) + if with_score_factors: + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + else: + score_factor_list = [None for _ in range(num_levels)] + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + head_cls_list=head_cls_list, + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + head_cls_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 5, H, W). + head_cls_list (list[Tensor]): Box head from all scale levels of a + single image, each item has shape (num_priors * 4, H, W) + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if not self.head_detec: + return super()._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + else: + assert head_cls_list is not None, 'predict head_cls_list is None!' + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. 
+ with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_head_clses = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + for level_idx, (cls_score, bbox_pred, head_cls, score_factor, priors)\ + in enumerate(zip(cls_score_list, bbox_pred_list, head_cls_list, + score_factor_list, mlvl_priors)): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + assert cls_score.size()[-2:] == head_cls.size()[-2:] + + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + head_cls = head_cls.permute(1, 2, 0).reshape(-1, 4).sigmoid() + head_cls = torch.argmax(head_cls, dim=1) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + scores = cls_score.softmax(-1)[:, :-1] + score_thr = cfg.get('score_thr', 0) + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict(bbox_pred=bbox_pred, head_cls=head_cls, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + bbox_pred = filtered_results['bbox_pred'] + head_cls = filtered_results['head_cls'] + priors = filtered_results['priors'] + if with_score_factors: + score_factor = score_factor[keep_idxs] + mlvl_score_factors.append(score_factor) + mlvl_bbox_preds.append(bbox_pred) + mlvl_head_clses.append(head_cls) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.heads = torch.cat(mlvl_head_clses) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + if with_score_factors: + results.score_factors = torch.cat(mlvl_score_factors) + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) diff --git a/mmrotate/models/task_modules/assigners/rotate_iou2d_calculator.py b/mmrotate/models/task_modules/assigners/rotate_iou2d_calculator.py index 3674f6f4a..3e882c357 100644 --- a/mmrotate/models/task_modules/assigners/rotate_iou2d_calculator.py +++ b/mmrotate/models/task_modules/assigners/rotate_iou2d_calculator.py @@ -6,7 +6,8 @@ from mmrotate.registry import TASK_UTILS from mmrotate.structures.bbox import (QuadriBoxes, RotatedBoxes, - fake_rbbox_overlaps, rbbox_overlaps) + RotatedHeadBoxes, fake_rbbox_overlaps, + rbbox_overlaps) @TASK_UTILS.register_module() @@ -23,10 +24,14 @@ def __call__(self, Args: bboxes1 (:obj:`RotatedBoxes` or Tensor): bboxes have shape (m, 5) in format, shape (m, 6) in - format. + format, shape (m, 7) in + format. + RotatedHeadBoxes format. bboxes2 (:obj:`RotatedBoxes` or Tensor): bboxes have shape (n, 5) in format, shape (n, 6) in - format, or be empty. + format, shape (m, 7) in + format. + RotatedHeadBoxes format or be empty. mode (str): 'iou' (intersection over union), 'iof' (intersection over foreground). Defaults to 'iou'. is_aligned (bool): If True, then m and n must be equal. 
@@ -35,16 +40,18 @@ def __call__(self, Returns: Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) """ - assert bboxes1.size(-1) in [0, 5, 6] - assert bboxes2.size(-1) in [0, 5, 6] + assert bboxes1.size(-1) in [0, 5, 6, 7] + assert bboxes2.size(-1) in [0, 5, 6, 7] if bboxes1.size(-1) == 6: bboxes1 = bboxes1[..., :5] if bboxes2.size(-1) == 6: bboxes2 = bboxes2[..., :5] - bboxes1 = get_box_tensor(bboxes1) - bboxes2 = get_box_tensor(bboxes2) + bboxes1 = get_box_tensor(bboxes1)[..., :5] if \ + isinstance(bboxes1, RotatedHeadBoxes) else get_box_tensor(bboxes1) + bboxes2 = get_box_tensor(bboxes2)[..., :5] if \ + isinstance(bboxes2, RotatedHeadBoxes) else get_box_tensor(bboxes2) return rbbox_overlaps(bboxes1, bboxes2, mode, is_aligned) diff --git a/mmrotate/models/task_modules/coders/delta_xywht_rbbox_coder.py b/mmrotate/models/task_modules/coders/delta_xywht_rbbox_coder.py index 64d4d19ca..a0d33d556 100644 --- a/mmrotate/models/task_modules/coders/delta_xywht_rbbox_coder.py +++ b/mmrotate/models/task_modules/coders/delta_xywht_rbbox_coder.py @@ -8,7 +8,7 @@ from torch import Tensor from mmrotate.registry import TASK_UTILS -from mmrotate.structures.bbox import RotatedBoxes +from mmrotate.structures.bbox import RotatedBoxes, RotatedHeadBoxes from ....structures.bbox.transforms import norm_angle @@ -77,7 +77,8 @@ def encode(self, bboxes: RotatedBoxes, gt_bboxes: RotatedBoxes) -> Tensor: """ assert bboxes.size(0) == gt_bboxes.size(0) assert bboxes.size(-1) == 5 - assert gt_bboxes.size(-1) == 5 + # The length of last dim of bboxes in RotatedHeadBoxes is 7 + assert gt_bboxes.size(-1) == 5 or gt_bboxes.size(-1) == 7 return bbox2delta(bboxes, gt_bboxes, self.means, self.stds, self.angle_version, self.norm_factor, self.edge_swap, self.proj_xy) @@ -154,7 +155,11 @@ def bbox2delta(proposals: RotatedBoxes, Tensor: deltas with shape (N, 5), where columns represent dx, dy, dw, dh, dt. 
""" - assert proposals.size() == gts.size() + if isinstance(gts, RotatedHeadBoxes): + assert proposals.shape[:-1] == gts.shape[:-1] + assert proposals.size(-1) == (gts.size(-1) - 2) + else: + assert proposals.size() == gts.size() proposals = proposals.tensor gts = gts.regularize_boxes(angle_version) proposals = proposals.float() diff --git a/mmrotate/structures/bbox/__init__.py b/mmrotate/structures/bbox/__init__.py index 895ade012..93bc9f036 100644 --- a/mmrotate/structures/bbox/__init__.py +++ b/mmrotate/structures/bbox/__init__.py @@ -4,10 +4,12 @@ rbox2hbox, rbox2qbox) from .quadri_boxes import QuadriBoxes from .rotated_boxes import RotatedBoxes +from .rotated_head_boxes import RotatedHeadBoxes from .transforms import distance2obb, gaussian2bbox, gt2gaussian, norm_angle __all__ = [ 'QuadriBoxes', 'RotatedBoxes', 'hbox2rbox', 'hbox2qbox', 'rbox2hbox', 'rbox2qbox', 'qbox2hbox', 'qbox2rbox', 'gaussian2bbox', 'gt2gaussian', - 'norm_angle', 'rbbox_overlaps', 'fake_rbbox_overlaps', 'distance2obb' + 'norm_angle', 'rbbox_overlaps', 'fake_rbbox_overlaps', 'distance2obb', + 'RotatedHeadBoxes' ] diff --git a/mmrotate/structures/bbox/box_converters.py b/mmrotate/structures/bbox/box_converters.py index 460177175..092db4f2f 100644 --- a/mmrotate/structures/bbox/box_converters.py +++ b/mmrotate/structures/bbox/box_converters.py @@ -7,6 +7,7 @@ from .quadri_boxes import QuadriBoxes from .rotated_boxes import RotatedBoxes +from .rotated_head_boxes import RotatedHeadBoxes @register_box_converter(HorizontalBoxes, RotatedBoxes) @@ -113,3 +114,17 @@ def qbox2rbox(boxes: Tensor) -> Tensor: rboxes.append([x, y, w, h, angle / 180 * np.pi]) rboxes = boxes.new_tensor(rboxes) return rboxes.view(*original_shape, 5) + + +@register_box_converter(RotatedHeadBoxes, RotatedBoxes) +def rheadbox2rbox(boxes: Tensor) -> Tensor: + """Convert quadrilateral boxes to rotated boxes. + + Args: + boxes (Tensor): RotatedHeadBoxes box tensor with shape of (..., 7). + + Returns: + Tensor: Rotated box tensor with shape of (..., 5). + """ + rboxes = boxes[..., :5] + return rboxes diff --git a/mmrotate/structures/bbox/rotated_head_boxes.py b/mmrotate/structures/bbox/rotated_head_boxes.py new file mode 100644 index 000000000..b06abe698 --- /dev/null +++ b/mmrotate/structures/bbox/rotated_head_boxes.py @@ -0,0 +1,360 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence, Tuple, TypeVar, Union + +import cv2 +import numpy as np +import torch +from mmdet.structures.bbox import register_box +from mmdet.structures.mask import BitmapMasks, PolygonMasks +from torch import BoolTensor, Tensor + +from .rotated_boxes import RotatedBoxes + +T = TypeVar('T') +DeviceType = Union[str, torch.device] +MaskType = Union[BitmapMasks, PolygonMasks] + + +@register_box('rheadbox') +class RotatedHeadBoxes(RotatedBoxes): + """The rotated head box class used in MMRotate. + + The 'RotatedHeadBoxes' add the head of rotated bboxes. You can get the head + of rotated bboxes by property 'head', and all the other functions and + properties are exactly the same in RotatedBoxes. + + Args: + data (Tensor or np.ndarray or Sequence): The box data with shape + (..., 10). The + the head quadrant of rotated bboxes. + dtype (torch.dtype, Optional): data type of boxes. Defaults to None. + device (str or torch.device, Optional): device of boxes. + Default to None. + clone (bool): Whether clone ``boxes`` or not. Defaults to True. 
+ """ + + box_dim = 7 + + def __init__(self, + data: Union[Tensor, np.ndarray], + dtype: torch.dtype = None, + device: DeviceType = None, + clone: bool = True) -> None: + if isinstance(data, (np.ndarray, Tensor, Sequence)): + data = torch.as_tensor(data) + else: + raise TypeError('boxes should be Tensor, ndarray, or Sequence, ', + f'but got {type(data)}') + if device is not None or dtype is not None: + data = data.to(dtype=dtype, device=device) + # Clone the data to avoid potential bugs + if clone: + data = data.clone() + # handle the empty input like [] + if data.numel() == 0: + data = data.reshape((-1, self.box_dim)) + assert data.dim() >= 2 and (data.size(-1) == self.box_dim or + data.size(-1) == 10),\ + ('The boxes dimension must >= 2 and the length of the last ' + f'dimension must be {self.box_dim}, but got boxes with ' + f'shape {data.shape}.') + if data.size(-1) == 10: + self.tensor = self.get_head_and_rbox(data[..., :-2], data[..., + -2:]) + elif data.size(-1) == self.box_dim: + self.tensor = data + else: + raise ValueError( + 'The length of last dim of data is not correct, }' + f'expect {self.box_dim}, but got data with shape {data.shape}.' + ) + + def get_head_and_rbox(self, qbox_data: Tensor, head_data: Tensor): + """get head quadrant and rbox params from qbox data and head data. + + Args: + qbox_data (Tensor): qbox format data, shape(..., 8). + head_data (Tensor): head xy data, shape(..., 2). + """ + assert qbox_data.size(-1) == 8 and head_data.size(-1) == 2, \ + ('The last dimension of two input params is incorrect, expect 2 ' + 'or 8, ' + f'but got qbox_data {qbox_data.shape}, head_data ' + f'{head_data.shape}.') + original_shape = qbox_data.shape[:-1] + points = qbox_data.cpu().numpy().reshape(-1, 4, 2) + rboxes = [] + for pts in points: + (x, y), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([x, y, w, h, angle / 180 * np.pi]) + rboxes = qbox_data.new_tensor(rboxes) + rboxes = rboxes.view(*original_shape, 5) + return torch.cat([rboxes, head_data], dim=-1) + + def get_head_quadrant(self, center_xys: Tensor, head_xys: Tensor): + """get head quadrant from head xy. + + Args: + center_xys (Tensor): bboxes center coordinates, shape (..., 2) + head_xys (Tensor): heads xy coordinates, shape (m, 2) + """ + center_xys = center_xys.reshape(-1, 2) + original_shape = head_xys.shape[:-1] + assert center_xys.size(-1) == 2 and head_xys.size(-1) == 2, \ + ('The last dimension of two input params must be 2, representing ' + f'xy coordinates, but got center_xys {center_xys.shape}, ' + f'head_xys {head_xys.shape}.') + head_quadrants = [] + for center_xy, head_xy in zip(center_xys, head_xys): + delta_x = head_xy[0] - center_xy[0] + delta_y = head_xy[1] - center_xy[1] + if (delta_x >= 0) and (delta_y >= 0): + head_quadrants.append(0) + elif (delta_x >= 0) and (delta_y <= 0): + head_quadrants.append(1) + elif (delta_x <= 0) and (delta_y <= 0): + head_quadrants.append(2) + else: + head_quadrants.append(3) + head_quadrants = head_xys.new_tensor(head_quadrants) + head_quadrants = head_quadrants.view(*original_shape) + return head_quadrants + + def regularize_boxes(self, + pattern: Optional[str] = None, + width_longer: bool = True, + start_angle: float = -90) -> Tensor: + """Regularize rotated boxes according to angle pattern and start angle. + + Args: + pattern (str, Optional): Regularization pattern. Can only be 'oc', + 'le90', or 'le135'. Defaults to None. + width_longer (bool): Whether to make sure width is larger than + height. Defaults to True. 
+ start_angle (float): The starting angle of the box angle + represented in degrees. Defaults to -90. + + Returns: + Tensor: Regularized box tensor + """ + boxes, headxs, headys = self.tensor[..., :5], self.tensor[ + ..., -2], self.tensor[..., -1] + if pattern is not None: + if pattern == 'oc': + width_longer, start_angle = False, -90 + elif pattern == 'le90': + width_longer, start_angle = True, -90 + elif pattern == 'le135': + width_longer, start_angle = True, -45 + else: + raise ValueError("pattern only can be 'oc', 'le90', and" + f"'le135', but get {pattern}.") + start_angle = start_angle / 180 * np.pi + x, y, w, h, t = boxes.unbind(dim=-1) + if width_longer: + # swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + np.pi / 2) + t = ((t - start_angle) % np.pi) + start_angle + else: + # swap edge and angle if angle > pi/2 + t = ((t - start_angle) % np.pi) + w_ = torch.where(t < np.pi / 2, w, h) + h_ = torch.where(t < np.pi / 2, h, w) + t = torch.where(t < np.pi / 2, t, t - np.pi / 2) + start_angle + self.tensor = torch.stack([x, y, w_, h_, t, headxs, headys], dim=-1) + return self.tensor[..., :5] + + @property + def heads(self) -> Tensor: + """Return a tensor representing the heads quadrant of boxes. + + if boxes have shape of (m, 8) or (m, 5), heads quadrant have shape of + (m,) + """ + center_xys, head_xys = self.tensor[..., :2], self.tensor[..., -2:] + assert center_xys.size(-1) == 2 and head_xys.size(-1) == 2, \ + ('The last dimension of two input params must be 2, representing ' + f'xy coordinates, but got center_xys {center_xys.shape}, ' + f'head_xys {head_xys.shape}.') + head_quadrants = self.get_head_quadrant(center_xys, head_xys) + return head_quadrants.int() + + @property + def head_xys(self) -> Tensor: + """Return a tensor representing the heads xy of boxes. + + if boxes have shape of (m, 8) or (m, 5), head_xys have shape of (m, 2) + """ + head_xys = self.tensor[..., -2:] + return head_xys + + def flip_(self, + img_shape: Tuple[int, int], + direction: str = 'horizontal') -> None: + """Flip boxes horizontally or vertically in-place. + + Args: + img_shape (Tuple[int, int]): A tuple of image height and width. + direction (str): Flip direction, options are "horizontal", + "vertical" and "diagonal". Defaults to "horizontal" + """ + assert direction in ['horizontal', 'vertical', 'diagonal'] + flipped = self.tensor + if direction == 'horizontal': + flipped[..., 0] = img_shape[1] - flipped[..., 0] + flipped[..., -2] = img_shape[1] - flipped[..., -2] + flipped[..., 4] = -flipped[..., 4] + elif direction == 'vertical': + flipped[..., 1] = img_shape[0] - flipped[..., 1] + flipped[..., -1] = img_shape[0] - flipped[..., -1] + flipped[..., 4] = -flipped[..., 4] + else: + flipped[..., 0] = img_shape[1] - flipped[..., 0] + flipped[..., -2] = img_shape[1] - flipped[..., -2] + flipped[..., 1] = img_shape[0] - flipped[..., 1] + flipped[..., -1] = img_shape[0] - flipped[..., -1] + + def rotate_(self, center: Tuple[float, float], angle: float) -> None: + """Rotate all boxes in-place. + + Args: + center (Tuple[float, float]): Rotation origin. + angle (float): Rotation angle represented in degree. Positive + values mean clockwise rotation. 
+ """ + boxes = self.tensor + rotation_matrix = boxes.new_tensor( + cv2.getRotationMatrix2D(center, -angle, 1)) + centers, wh, t, head_xys = torch.split(boxes, [2, 2, 1, 2], dim=-1) + t = t + angle / 180 * np.pi + centers = torch.cat( + [centers, centers.new_ones(*centers.shape[:-1], 1)], dim=-1) + head_xys = torch.cat( + [head_xys, head_xys.new_ones(*head_xys.shape[:-1], 1)], dim=-1) + centers_T = torch.transpose(centers, -1, -2) + centers_T = torch.matmul(rotation_matrix, centers_T) + head_xys_T = torch.transpose(head_xys, -1, -2) + head_xys_T = torch.matmul(rotation_matrix, head_xys_T) + centers = torch.transpose(centers_T, -1, -2) + head_xys = torch.transpose(head_xys_T, -1, -2) + self.tensor = torch.cat([centers, wh, t, head_xys], dim=-1) + + def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None: + """Geometric transformat boxes in-place. + + Args: + homography_matrix (Tensor or np.ndarray]): + Shape (3, 3) for geometric transformation. + """ + boxes, head_xys = self.tensor[..., :5], self.tensor[..., 5:] + if isinstance(homography_matrix, np.ndarray): + homography_matrix = boxes.new_tensor(homography_matrix) + corners = self.rbox2corner(boxes) + corners = torch.cat( + [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1) + head_xys = torch.cat( + [head_xys, head_xys.new_ones(*head_xys.shape[:-1], 1)], dim=-1) + corners_T = torch.transpose(corners, -1, -2) + corners_T = torch.matmul(homography_matrix, corners_T) + head_xys_T = torch.transpose(head_xys, -1, -2) + head_xys_T = torch.matmul(homography_matrix, head_xys_T) + corners = torch.transpose(corners_T, -1, -2) + head_xys = torch.transpose(head_xys_T, -1, -2) + # Convert to homogeneous coordinates by normalization + corners = corners[..., :2] / corners[..., 2:3] + head_xys = head_xys[..., :2] / head_xys[..., 2:3] + boxes = self.corner2rbox(corners) + self.tensor = torch.cat([boxes, head_xys], dim=-1) + + def rescale_(self, scale_factor: Tuple[float, float]) -> None: + """Rescale boxes w.r.t. rescale_factor in-place. + + Note: + Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes + w.r.t ``scale_facotr``. The difference is that ``resize_`` only + changes the width and the height of boxes, but ``rescale_`` also + rescales the box centers simultaneously. + + Args: + scale_factor (Tuple[float, float]): factors for scaling boxes. + The length should be 2. + """ + boxes, head_xys = self.tensor[..., :5], self.tensor[..., 5:] + assert len(scale_factor) == 2 + scale_x, scale_y = scale_factor + ctrs, w, h, t = torch.split(boxes, [2, 1, 1, 1], dim=-1) + cos_value, sin_value = torch.cos(t), torch.sin(t) + + # Refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/structures/rotated_boxes.py # noqa + # rescale centers and head_xys + ctrs = ctrs * ctrs.new_tensor([scale_x, scale_y]) + head_xys = head_xys * head_xys.new_tensor([scale_x, scale_y]) + # rescale width and height + w = w * torch.sqrt((scale_x * cos_value)**2 + (scale_y * sin_value)**2) + h = h * torch.sqrt((scale_x * sin_value)**2 + (scale_y * cos_value)**2) + # recalculate theta + t = torch.atan2(scale_x * sin_value, scale_y * cos_value) + self.tensor = torch.cat([ctrs, w, h, t, head_xys], dim=-1) + + def resize_(self, scale_factor: Tuple[float, float]) -> None: + """Resize the box width and height w.r.t scale_factor in-place. + + Note: + Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes + w.r.t ``scale_facotr``. 
The difference is that ``resize_`` only
+            changes the width and the height of boxes, but ``rescale_`` also
+            rescales the box centers simultaneously.
+
+        Args:
+            scale_factor (Tuple[float, float]): factors for scaling box
+                shapes. The length should be 2.
+        """
+        boxes, head_xys = self.tensor[..., :5], self.tensor[..., 5:]
+        assert len(scale_factor) == 2
+        ctrs, wh, t = torch.split(boxes, [2, 2, 1], dim=-1)
+        scale_factor = boxes.new_tensor(scale_factor)
+        wh = wh * scale_factor
+        self.tensor = torch.cat([ctrs, wh, t, head_xys], dim=-1)
+
+    def find_inside_points(self,
+                           points: Tensor,
+                           is_aligned: bool = False,
+                           eps: float = 0.01) -> BoolTensor:
+        """Find inside box points. Boxes dimension must be 2.
+
+        Args:
+            points (Tensor): Points coordinates. Has shape of (m, 2).
+            is_aligned (bool): Whether ``points`` has been aligned with boxes
+                or not. If True, the length of boxes and ``points`` should be
+                the same. Defaults to False.
+            eps (float): Make sure the points are inside not on the boundary.
+                Defaults to 0.01.
+
+        Returns:
+            BoolTensor: A BoolTensor indicating whether each point is inside
+            the boxes. Assuming the boxes have shape of (n, 5), the index has
+            shape of (m, n) if ``is_aligned`` is False. If ``is_aligned`` is
+            True, m should be equal to n and the index has shape of (m, ).
+        """
+        boxes = self.tensor[..., :5]
+        assert boxes.dim() == 2, 'boxes dimension must be 2.'
+
+        if not is_aligned:
+            boxes = boxes[None, :, :]
+            points = points[:, None, :]
+        else:
+            assert boxes.size(0) == points.size(0)
+
+        ctrs, wh, t = torch.split(boxes, [2, 2, 1], dim=-1)
+        cos_value, sin_value = torch.cos(t), torch.sin(t)
+        matrix = torch.cat([cos_value, sin_value, -sin_value, cos_value],
+                           dim=-1).reshape(*boxes.shape[:-1], 2, 2)
+
+        offset = points - ctrs
+        offset = torch.matmul(matrix, offset[..., None])
+        offset = offset.squeeze(-1)
+        offset_x, offset_y = offset[..., 0], offset[..., 1]
+        w, h = wh[..., 0], wh[..., 1]
+        return (offset_x <= w / 2 - eps) & (offset_x >= - w / 2 + eps) & \
+            (offset_y <= h / 2 - eps) & (offset_y >= - h / 2 + eps)
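
The patch relies on one shared convention: an object's head is reduced to a quadrant id computed from the offset of the head point relative to the box center (see ``RotatedHeadBoxes.heads``, ``OHD_SJTUMetric.get_head_quadrant`` and the ``rheadbox`` branch of ``tpfp_default``). The standalone sketch below restates that convention and the head-accuracy bookkeeping; the helper names ``head_quadrant`` and ``head_accuracy`` are illustrative only and not part of the patch.

import numpy as np


def head_quadrant(center_xy, head_xy):
    # Quadrant id of the head point relative to the box center, mirroring the
    # if/elif chain used in the patch:
    # (+dx, +dy) -> 0, (+dx, -dy) -> 1, (-dx, -dy) -> 2, otherwise -> 3.
    dx = head_xy[0] - center_xy[0]
    dy = head_xy[1] - center_xy[1]
    if dx >= 0 and dy >= 0:
        return 0
    if dx >= 0 and dy <= 0:
        return 1
    if dx <= 0 and dy <= 0:
        return 2
    return 3


def head_accuracy(tp, pred_quadrants, gt_quadrants, matched_gt_inds):
    # Sketch of the ``rheadbox`` accounting in ``eval_rbbox_map``: a detection
    # contributes to ``tp_head`` only when it is a true positive (IoU match)
    # *and* its predicted quadrant equals that of the matched gt; the accuracy
    # is then averaged over all detections. Assumes every detection is matched.
    tp = np.asarray(tp, dtype=bool)
    quad_match = (np.asarray(pred_quadrants) ==
                  np.asarray(gt_quadrants)[matched_gt_inds])
    tp_head = tp & quad_match
    return tp_head.sum() / len(tp_head)


# A head at (14, 7) for a box centered at (10, 10) lies in quadrant 1.
print(head_quadrant((10.0, 10.0), (14.0, 7.0)))       # 1
# Two IoU-matched detections, only the first with the correct quadrant -> 0.5.
print(head_accuracy([1, 1], [1, 3], [1, 0], [0, 1]))  # 0.5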