Merge pull request #131 from aim-uofa/crop_augmentation_fix
fix crop augmentation #127, #129
tianzhi0549 authored Jul 2, 2020
2 parents db7f0b1 + 2169216 commit c6fa04e
Showing 3 changed files with 226 additions and 148 deletions.
123 changes: 123 additions & 0 deletions adet/data/augmentation.py
@@ -0,0 +1,123 @@
import numpy as np
from fvcore.transforms import transform as T

from detectron2.data.transforms import RandomCrop, StandardAugInput
from detectron2.structures import BoxMode


class InstanceAugInput(StandardAugInput):
    """
    Keep the old behavior of instance-aware augmentation
    """

    def __init__(self, *args, **kwargs):
        instances = kwargs.pop("instances", None)
        super().__init__(*args, **kwargs)
        if instances is not None:
            self.instances = instances


def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True):
    """
    Generate a CropTransform so that the cropping region contains
    the center of a randomly chosen instance.
    Args:
        crop_size (tuple): h, w in pixels
        image_size (tuple): h, w
        instances (list[dict]): annotation dicts of the instances, in
            Detectron2's dataset format.
        crop_box (bool): if False, extend the crop region so that no
            instance box is cut by its edges.
    """
    instance = np.random.choice(instances)
    crop_size = np.asarray(crop_size, dtype=np.int32)
    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
    assert (
        image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
    ), "The annotation bounding box is outside of the image!"
    assert (
        image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
    ), "Crop size is larger than image size!"

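    # Sample a top-left corner that keeps the chosen instance's center inside
    # the crop while the crop itself stays inside the image.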
    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))

    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)

    # If instance boxes must not be cut, repeatedly extend the crop region
    # until every box lies fully inside or fully outside of it.
    if not crop_box:
        num_modifications = 0
        modified = True

        # convert crop_size to float
        crop_size = crop_size.astype(np.float32)
        while modified:
            modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances)
            num_modifications += 1
            if num_modifications > 100:
                raise ValueError(
                    "Cannot finish the crop adjustment within 100 tries (#instances {}).".format(
                        len(instances)
                    )
                )

    return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0])))


def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
    modified = False

    x1 = x0 + crop_size[1]
    y1 = y0 + crop_size[0]

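    # Grow the window toward any instance box that straddles one of its four
    # edges; the caller loops until a full pass makes no further changes.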
    for instance in instances:
        bbox = BoxMode.convert(
            instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS
        )

        if bbox[0] < x0 - eps and bbox[2] > x0 + eps:
            crop_size[1] += x0 - bbox[0]
            x0 = bbox[0]
            modified = True

        if bbox[0] < x1 - eps and bbox[2] > x1 + eps:
            crop_size[1] += bbox[2] - x1
            x1 = bbox[2]
            modified = True

        if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
            crop_size[0] += y0 - bbox[1]
            y0 = bbox[1]
            modified = True

        if bbox[1] < y1 - eps and bbox[3] > y1 + eps:
            crop_size[0] += bbox[3] - y1
            y1 = bbox[3]
            modified = True

    return modified, x0, y0, crop_size


class RandomCropWithInstance(RandomCrop):
    """ Instance-aware cropping.
    """

    def __init__(self, crop_type, crop_size, crop_instance=True):
        """
        Args:
            crop_instance (bool): if False, extend cropping boxes to avoid cropping instances
        """
        super().__init__(crop_type, crop_size)
        self.crop_instance = crop_instance
        self.input_args = ("image", "instances")

    def get_transform(self, img, instances):
        image_size = img.shape[:2]
        crop_size = self.get_crop_size(image_size)
        return gen_crop_transform_with_instance(
            crop_size, image_size, instances, crop_box=self.crop_instance
        )
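Aside: a minimal NumPy-only sketch of the invariant the sampling above enforces — the chosen instance's center stays inside the crop and the crop stays inside the image. sample_crop_corner is a hypothetical name and all sizes below are made up for illustration; in the mapper changes further down, RandomCropWithInstance receives both the image and the instance annotations via self.input_args.

import numpy as np

def sample_crop_corner(crop_size, image_size, center_yx):
    # Mirrors the min_yx/max_yx logic of gen_crop_transform_with_instance.
    crop_size = np.asarray(crop_size, dtype=np.int32)
    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))
    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
    return y0, x0

# Made-up numbers: a 300x400 crop from a 480x640 image, center at (100, 200).
crop_h, crop_w, H, W, cy, cx = 300, 400, 480, 640, 100.0, 200.0
for _ in range(1000):
    y0, x0 = sample_crop_corner((crop_h, crop_w), (H, W), (cy, cx))
    assert 0 <= y0 and y0 + crop_h <= H and 0 <= x0 and x0 + crop_w <= W
    assert y0 <= cy <= y0 + crop_h and x0 <= cx <= x0 + crop_w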
128 changes: 79 additions & 49 deletions adet/data/dataset_mapper.py
@@ -1,22 +1,21 @@
import copy
-import numpy as np
+import logging
+import os.path as osp
+
+import numpy as np
import torch
from fvcore.common.file_io import PathManager
from PIL import Image
-import logging
+from pycocotools import mask as maskUtils

-from detectron2.data.dataset_mapper import DatasetMapper
-from detectron2.data.detection_utils import SizeMismatchError
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
+from detectron2.data.dataset_mapper import DatasetMapper
+from detectron2.data.detection_utils import SizeMismatchError

-from .detection_utils import (
-    build_augmentation,
-    transform_instance_annotations,
-    annotations_to_instances,
-    gen_crop_transform_with_instance,
-)
+from .augmentation import InstanceAugInput, RandomCropWithInstance
+from .detection_utils import (annotations_to_instances, build_augmentation,
+                              transform_instance_annotations)

"""
This file contains the default mapping that's applied to "dataset dicts".
@@ -27,6 +26,28 @@
logger = logging.getLogger(__name__)


+def segmToRLE(segm, img_size):
+    h, w = img_size
+    if type(segm) == list:
+        # polygon -- a single object might consist of multiple parts
+        # we merge all parts into one mask rle code
+        rles = maskUtils.frPyObjects(segm, h, w)
+        rle = maskUtils.merge(rles)
+    elif type(segm["counts"]) == list:
+        # uncompressed RLE
+        rle = maskUtils.frPyObjects(segm, h, w)
+    else:
+        # rle
+        rle = segm
+    return rle
+
+
+def segmToMask(segm, img_size):
+    rle = segmToRLE(segm, img_size)
+    m = maskUtils.decode(rle)
+    return m


class DatasetMapperWithBasis(DatasetMapper):
    """
    This caller enables the default Detectron2 mapper to read an additional basis semantic label
Expand All @@ -36,13 +57,27 @@ def __init__(self, cfg, is_train=True):
        super().__init__(cfg, is_train)

        # Rebuild augmentations
-        logger.info("Rebuilding the augmentations. The previous augmentations will be overridden.")
+        logger.info(
+            "Rebuilding the augmentations. The previous augmentations will be overridden."
+        )
        self.augmentation = build_augmentation(cfg, is_train)

+        if cfg.INPUT.CROP.ENABLED and is_train:
+            self.augmentation.insert(
+                0,
+                RandomCropWithInstance(
+                    cfg.INPUT.CROP.TYPE,
+                    cfg.INPUT.CROP.SIZE,
+                    cfg.INPUT.CROP.CROP_INSTANCE,
+                ),
+            )
+            logging.getLogger(__name__).info(
+                "Cropping used in training: " + str(self.augmentation[0])
+            )
+
        # fmt: off
-        self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
-        self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET
-        self.crop_box = cfg.INPUT.CROP.CROP_INSTANCE
+        self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
+        self.ann_set = cfg.MODEL.BASIS_MODULE.ANN_SET
        # fmt: on

    def __call__(self, dataset_dict):
@@ -72,35 +107,27 @@ def __call__(self, dataset_dict):
            else:
                raise e

-        if "annotations" not in dataset_dict or len(dataset_dict["annotations"]) == 0:
-            image, transforms = T.apply_augmentations(
-                ([self.crop] if self.crop else []) + self.augmentation, image
-            )
-        else:
-            # Crop around an instance if there are instances in the image.
-            # USER: Remove if you don't use cropping
-            if self.crop:
-                crop_tfm = gen_crop_transform_with_instance(
-                    self.crop.get_crop_size(image.shape[:2]),
-                    image.shape[:2],
-                    dataset_dict["annotations"],
-                    crop_box=self.crop_box,
-                )
-                image = crop_tfm.apply_image(image)
-            try:
-                image, transforms = T.apply_augmentations(self.augmentation, image)
-            except ValueError as e:
-                print(dataset_dict["file_name"])
-                raise e
-            if self.crop:
-                transforms = crop_tfm + transforms
+        # USER: Remove if you don't do semantic/panoptic segmentation.
+        if "sem_seg_file_name" in dataset_dict:
+            sem_seg_gt = utils.read_image(
+                dataset_dict.pop("sem_seg_file_name"), "L"
+            ).squeeze(2)
+        else:
+            sem_seg_gt = None

-        image_shape = image.shape[:2]  # h, w
+        aug_input = InstanceAugInput(image, sem_seg=sem_seg_gt, instances=dataset_dict["annotations"])
+        transforms = aug_input.apply_augmentations(self.augmentation)
+        image, sem_seg_gt = aug_input.image, aug_input.sem_seg
+
+        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
-        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+        dataset_dict["image"] = torch.as_tensor(
+            np.ascontiguousarray(image.transpose(2, 0, 1))
+        )
+        if sem_seg_gt is not None:
+            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
@@ -130,7 +157,10 @@ def __call__(self, dataset_dict):
            # USER: Implement additional transformations if you have other types of data
            annos = [
                transform_instance_annotations(
-                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+                    obj,
+                    transforms,
+                    image_shape,
+                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
@@ -143,24 +173,24 @@ def __call__(self, dataset_dict):
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
-            if self.crop and instances.has("gt_masks"):
+            if self.compute_tight_boxes and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

-        # USER: Remove if you don't do semantic/panoptic segmentation.
-        if "sem_seg_file_name" in dataset_dict:
-            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
-            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
-            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
-            dataset_dict["sem_seg"] = sem_seg_gt
-
        if self.basis_loss_on and self.is_train:
            # load basis supervisions
            if self.ann_set == "coco":
-                basis_sem_path = dataset_dict["file_name"].replace('train2017', 'thing_train2017').replace('image/train', 'thing_train')
+                basis_sem_path = (
+                    dataset_dict["file_name"]
+                    .replace("train2017", "thing_train2017")
+                    .replace("image/train", "thing_train")
+                )
            else:
-                basis_sem_path = dataset_dict["file_name"].replace('coco', 'lvis').replace('train2017', 'thing_train')
+                basis_sem_path = (
+                    dataset_dict["file_name"]
+                    .replace("coco", "lvis")
+                    .replace("train2017", "thing_train")
+                )
            # change extension to npz
            basis_sem_path = osp.splitext(basis_sem_path)[0] + ".npz"
            basis_sem_gt = np.load(basis_sem_path)["mask"]
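Aside: the new segmToRLE/segmToMask helpers wrap standard pycocotools calls. A small self-contained demo of what they compute, assuming pycocotools is installed; the polygon and canvas size are invented for illustration:

from pycocotools import mask as maskUtils

# One COCO-style polygon (x, y pairs): a 10x10 square on a 32x32 canvas.
segm = [[5.0, 5.0, 15.0, 5.0, 15.0, 15.0, 5.0, 15.0]]
h, w = 32, 32

# segmToRLE path for polygons: per-part RLEs, merged into a single RLE.
rles = maskUtils.frPyObjects(segm, h, w)
rle = maskUtils.merge(rles)

# segmToMask path: decode the RLE into a binary h x w mask.
m = maskUtils.decode(rle)
print(m.shape, m.sum())  # (32, 32) and roughly 100 foreground pixels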