Skip to content

Commit

Permalink
Refactor mask_target_single function to handle unsupported ground tru…
Browse files Browse the repository at this point in the history
…th mask types and provide warnings for missing ground truth masks
  • Loading branch information
eugene123tw committed Oct 10, 2024
1 parent 758ea97 commit 7bc9140
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import warnings

import numpy as np
import torch
from datumaro.components.annotation import Polygon
Expand Down Expand Up @@ -41,7 +43,7 @@ def mask_target(
"""
cfg_list = [cfg for _ in range(len(pos_proposals_list))]
mask_targets = map(
mask_target_single,
mask_target_single, # type: ignore[arg-type]
pos_proposals_list,
pos_assigned_gt_inds_list,
gt_masks_list,
Expand All @@ -58,20 +60,24 @@ def mask_target_single(
pos_proposals: Tensor,
pos_assigned_gt_inds: Tensor,
gt_masks: list[Polygon] | tv_tensors.Mask,
cfg: dict,
mask_size: list[int],
meta_info: dict,
) -> Tensor:
"""Compute mask target for each positive proposal in the image."""
mask_size = _pair(mask_size)
if len(gt_masks) == 0:
warnings.warn("No ground truth masks are provided!", stacklevel=2)
return pos_proposals.new_zeros((0, *mask_size))

if isinstance(gt_masks[0], Polygon):
crop_and_resize = crop_and_resize_polygons
elif isinstance(gt_masks, tv_tensors.Mask):
crop_and_resize = crop_and_resize_masks
else:
msg = f"Unsupported type of masks: {type(gt_masks[0])}"
raise NotImplementedError(msg)
warnings.warn("Unsupported ground truth mask type!", stacklevel=2)
return pos_proposals.new_zeros((0, *mask_size))

device = pos_proposals.device
mask_size = _pair(cfg["mask_size"])
num_pos = pos_proposals.size(0)
if num_pos > 0:
proposals_np = pos_proposals.cpu().numpy()
Expand All @@ -83,7 +89,7 @@ def mask_target_single(
mask_targets = crop_and_resize(
gt_masks,
proposals_np,
mask_size,
mask_size, # type: ignore[arg-type]
inds=pos_assigned_gt_inds,
device=device,
)
Expand Down
6 changes: 4 additions & 2 deletions src/otx/core/data/pre_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def pre_filtering(
dataset = DmDataset.filter(dataset, is_valid_annot, filter_annotations=True)
dataset = remove_unused_labels(dataset, data_format, ignore_index)
if unannotated_items_ratio > 0:
empty_items = [item.id for item in dataset if item.subset == "train" and len(item.annotations) == 0]
empty_items = [
item.id for item in dataset if item.subset in ("train", "TRAINING") and len(item.annotations) == 0
]
used_background_items = set(sample(empty_items, int(len(empty_items) * unannotated_items_ratio)))

return DmDataset.filter(
dataset,
lambda item: not (
item.subset == "train" and len(item.annotations) == 0 and item.id not in used_background_items
item.subset in ("train", "TRAINING") and len(item.annotations) == 0 and item.id not in used_background_items
),
)

Expand Down
110 changes: 110 additions & 0 deletions tests/unit/algo/common/test_iou2d_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch


from otx.algo.common.utils.bbox_overlaps import bbox_overlaps
from otx.algo.common.utils.assigners.iou2d_calculator import BboxOverlaps2D


def test_bbox_overlaps_2d(eps=1e-7):

def _construct_bbox(num_bbox=None):
img_h = int(np.random.randint(3, 1000))
img_w = int(np.random.randint(3, 1000))
if num_bbox is None:
num_bbox = np.random.randint(1, 10)
x1y1 = torch.rand((num_bbox, 2))
x2y2 = torch.max(torch.rand((num_bbox, 2)), x1y1)
bboxes = torch.cat((x1y1, x2y2), -1)
bboxes[:, 0::2] *= img_w
bboxes[:, 1::2] *= img_h
return bboxes, num_bbox

# is_aligned is True, bboxes.size(-1) == 5 (include score)
self = BboxOverlaps2D()
bboxes1, num_bbox = _construct_bbox()
bboxes2, _ = _construct_bbox(num_bbox)
bboxes1 = torch.cat((bboxes1, torch.rand((num_bbox, 1))), 1)
bboxes2 = torch.cat((bboxes2, torch.rand((num_bbox, 1))), 1)
gious = self(bboxes1, bboxes2, 'giou', True)
assert gious.size() == (num_bbox, ), gious.size()
assert torch.all(gious >= -1) and torch.all(gious <= 1)

# is_aligned is True, bboxes1.size(-2) == 0
bboxes1 = torch.empty((0, 4))
bboxes2 = torch.empty((0, 4))
gious = self(bboxes1, bboxes2, 'giou', True)
assert gious.size() == (0, ), gious.size()
assert torch.all(gious == torch.empty((0, )))
assert torch.all(gious >= -1) and torch.all(gious <= 1)

# is_aligned is True, and bboxes.ndims > 2
bboxes1, num_bbox = _construct_bbox()
bboxes2, _ = _construct_bbox(num_bbox)
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
# test assertion when batch dim is not the same
with pytest.raises(ValueError):
self(bboxes1, bboxes2.unsqueeze(0).repeat(3, 1, 1), 'giou', True)
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
gious = self(bboxes1, bboxes2, 'giou', True)
assert torch.all(gious >= -1) and torch.all(gious <= 1)
assert gious.size() == (2, num_bbox)
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1, 1)
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1, 1)
gious = self(bboxes1, bboxes2, 'giou', True)
assert torch.all(gious >= -1) and torch.all(gious <= 1)
assert gious.size() == (2, 2, num_bbox)

# is_aligned is False
bboxes1, num_bbox1 = _construct_bbox()
bboxes2, num_bbox2 = _construct_bbox()
gious = self(bboxes1, bboxes2, 'giou')
assert torch.all(gious >= -1) and torch.all(gious <= 1)
assert gious.size() == (num_bbox1, num_bbox2)

# is_aligned is False, and bboxes.ndims > 2
bboxes1 = bboxes1.unsqueeze(0).repeat(2, 1, 1)
bboxes2 = bboxes2.unsqueeze(0).repeat(2, 1, 1)
gious = self(bboxes1, bboxes2, 'giou')
assert torch.all(gious >= -1) and torch.all(gious <= 1)
assert gious.size() == (2, num_bbox1, num_bbox2)
bboxes1 = bboxes1.unsqueeze(0)
bboxes2 = bboxes2.unsqueeze(0)
gious = self(bboxes1, bboxes2, 'giou')
assert torch.all(gious >= -1) and torch.all(gious <= 1)
assert gious.size() == (1, 2, num_bbox1, num_bbox2)

# is_aligned is False, bboxes1.size(-2) == 0
gious = self(torch.empty(1, 2, 0, 4), bboxes2, 'giou')
assert torch.all(gious == torch.empty(1, 2, 0, bboxes2.size(-2)))
assert torch.all(gious >= -1) and torch.all(gious <= 1)

# test allclose between bbox_overlaps and the original official
# implementation.
bboxes1 = torch.FloatTensor([
[0, 0, 10, 10],
[10, 10, 20, 20],
[32, 32, 38, 42],
])
bboxes2 = torch.FloatTensor([
[0, 0, 10, 20],
[0, 10, 10, 19],
[10, 10, 20, 20],
])
gious = bbox_overlaps(bboxes1, bboxes2, 'giou', is_aligned=True, eps=eps)
gious = gious.numpy().round(4)
# the gt is got with four decimal precision.
expected_gious = np.array([0.5000, -0.0500, -0.8214])
assert np.allclose(gious, expected_gious, rtol=0, atol=eps)

# test mode 'iof'
ious = bbox_overlaps(bboxes1, bboxes2, 'iof', is_aligned=True, eps=eps)
assert torch.all(ious >= -1) and torch.all(ious <= 1)
assert ious.size() == (bboxes1.size(0), )
ious = bbox_overlaps(bboxes1, bboxes2, 'iof', eps=eps)
assert torch.all(ious >= -1) and torch.all(ious <= 1)
assert ious.size() == (bboxes1.size(0), bboxes2.size(0))

0 comments on commit 7bc9140

Please sign in to comment.