Commit
update dataset mapper with latest detectron2 (#123)
* update dataset mapper with latest detectron2

* tf_gens-->augmentation

* fix bug
wangg12 authored Jun 27, 2020
1 parent b380f07 commit db7f0b1
Showing 2 changed files with 42 additions and 29 deletions.
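
The substance of the commit is a mechanical rename: detectron2 replaced its TransformGen classes with Augmentation, and apply_transform_gens with apply_augmentations, so AdelaiDet's mapper and helpers follow suit. A minimal sketch of the new-style call, assuming a detectron2 release new enough to ship T.apply_augmentations (the sizes and the dummy image are illustrative, not from this commit):

import numpy as np
from detectron2.data import transforms as T

# A list of Augmentation objects (formerly "transform gens").
augmentation = [
    T.ResizeShortestEdge(short_edge_length=[800], max_size=1333, sample_style="choice"),
    T.RandomFlip(),
]

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy HWC uint8 image
# New API spelling; apply_transform_gens survives upstream as a deprecated alias.
image, transforms = T.apply_augmentations(augmentation, image)
print(image.shape)  # shortest edge resized to 800, possibly flipped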
43 changes: 25 additions & 18 deletions adet/data/dataset_mapper.py
@@ -12,7 +12,7 @@
 from detectron2.data import transforms as T
 
 from .detection_utils import (
-    build_transform_gen,
+    build_augmentation,
     transform_instance_annotations,
     annotations_to_instances,
     gen_crop_transform_with_instance,
@@ -35,9 +35,9 @@ class DatasetMapperWithBasis(DatasetMapper):
     def __init__(self, cfg, is_train=True):
         super().__init__(cfg, is_train)
 
-        # Rebuild transform gen
-        logger.info("Rebuilding the transform generators. The previous generators will be overridden.")
-        self.tfm_gens = build_transform_gen(cfg, is_train)
+        # Rebuild augmentations
+        logger.info("Rebuilding the augmentations. The previous augmentations will be overridden.")
+        self.augmentation = build_augmentation(cfg, is_train)
 
         # fmt: off
         self.basis_loss_on = cfg.MODEL.BASIS_MODULE.LOSS_ON
@@ -73,40 +73,44 @@ def __call__(self, dataset_dict):
             raise e
 
         if "annotations" not in dataset_dict or len(dataset_dict["annotations"]) == 0:
-            image, transforms = T.apply_transform_gens(
-                ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image
+            image, transforms = T.apply_augmentations(
+                ([self.crop] if self.crop else []) + self.augmentation, image
             )
         else:
             # Crop around an instance if there are instances in the image.
             # USER: Remove if you don't use cropping
-            if self.crop_gen:
+            if self.crop:
                 crop_tfm = gen_crop_transform_with_instance(
-                    self.crop_gen.get_crop_size(image.shape[:2]),
+                    self.crop.get_crop_size(image.shape[:2]),
                     image.shape[:2],
                     dataset_dict["annotations"],
                     crop_box=self.crop_box,
                 )
                 image = crop_tfm.apply_image(image)
             try:
-                image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+                image, transforms = T.apply_augmentations(self.augmentation, image)
             except ValueError as e:
                 print(dataset_dict["file_name"])
                 raise e
-            if self.crop_gen:
+            if self.crop:
                 transforms = crop_tfm + transforms
 
         image_shape = image.shape[:2]  # h, w
 
         # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
         # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
         # Therefore it's important to use torch.Tensor.
-        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
+        # Can use uint8 if it turns out to be slow some day
+        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
 
         # USER: Remove if you don't use pre-computed proposals.
         # Most users would not need this feature.
         if self.load_proposals:
             utils.transform_proposals(
-                dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk
+                dataset_dict,
+                image_shape,
+                transforms,
+                proposal_topk=self.proposal_topk,
+                min_box_size=self.proposal_min_box_size,
             )
 
         if not self.is_train:
@@ -134,16 +138,19 @@ def __call__(self, dataset_dict):
             instances = annotations_to_instances(
                 annos, image_shape, mask_format=self.mask_format
             )
-            # Create a tight bounding box from masks, useful when image is cropped
-            if self.crop_gen and instances.has("gt_masks"):
+
+            # After transforms such as cropping are applied, the bounding box may no longer
+            # tightly bound the object. As an example, imagine a triangle object
+            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
+            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
+            # the intersection of original bounding box and the cropping box.
+            if self.crop and instances.has("gt_masks"):
                 instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
             dataset_dict["instances"] = utils.filter_empty_instances(instances)
 
         # USER: Remove if you don't do semantic/panoptic segmentation.
         if "sem_seg_file_name" in dataset_dict:
-            with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f:
-                sem_seg_gt = Image.open(f)
-                sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8")
+            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
             sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
             sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
             dataset_dict["sem_seg"] = sem_seg_gt
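
Everything above is renames (crop_gen to crop, tfm_gens to augmentation) plus updated upstream calls (keyword arguments for transform_proposals, utils.read_image for semantic segmentation); the mapper's input/output contract is unchanged. A usage sketch under that assumption; the config defaults, the path, and the sizes below are placeholders, not part of this commit:

from adet.config import get_cfg
from adet.data.dataset_mapper import DatasetMapperWithBasis

cfg = get_cfg()
mapper = DatasetMapperWithBasis(cfg, is_train=True)

# A COCO-style dataset dict; "file_name" must point at a real image on disk.
dataset_dict = {
    "file_name": "datasets/coco/train2017/example.jpg",  # placeholder path
    "height": 480,
    "width": 640,
    "image_id": 0,
    "annotations": [],  # empty list: resize/flip only, no instance-aware crop
}
out = mapper(dataset_dict)
# After this commit out["image"] is a contiguous CHW torch.Tensor that keeps
# the numpy dtype (typically uint8) instead of being cast to float32 up front.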
28 changes: 17 additions & 11 deletions adet/data/detection_utils.py
@@ -77,7 +77,7 @@ def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
             crop_size[1] += bbox[2] - x1
             x1 = bbox[2]
             modified = True
-
+
         if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
             crop_size[0] += y0 - bbox[1]
             y0 = bbox[1]
@@ -89,7 +89,7 @@ def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
             modified = True
 
     return modified, x0, y0, crop_size
-
+
 
 def transform_instance_annotations(
     annotation, transforms, image_size, *, keypoint_hflip_indices=None
@@ -98,7 +98,7 @@ def transform_instance_annotations(
     annotation = d2_transform_inst_anno(
         annotation, transforms, image_size,
         keypoint_hflip_indices=keypoint_hflip_indices)
-
+
     if "beziers" in annotation:
         beziers = transform_beziers_annotations(
             annotation["beziers"], transforms
@@ -140,16 +140,16 @@ def annotations_to_instances(annos, image_size, mask_format="polygon"):
         text = [obj.get("rec", []) for obj in annos]
         instance.text = torch.as_tensor(
             text, dtype=torch.int32)
-
+
     return instance
 
 
-def build_transform_gen(cfg, is_train):
+def build_augmentation(cfg, is_train):
     """
     With option to don't use hflip
     Returns:
-        list[TransformGen]
+        list[Augmentation]
     """
     if is_train:
         min_size = cfg.INPUT.MIN_SIZE_TRAIN
@@ -164,10 +164,16 @@ def build_transform_gen(cfg, is_train):
             len(min_size)
         )
 
-    tfm_gens = []
-    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
+    augmentation = []
+    augmentation.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
     if is_train:
         if cfg.INPUT.HFLIP_TRAIN:
-            tfm_gens.append(T.RandomFlip())
-    logger.info("TransformGens used in training: " + str(tfm_gens))
-    return tfm_gens
+            augmentation.append(T.RandomFlip())
+    logger.info("Augmentations used in training: " + str(augmentation))
+    return augmentation
+
+
+build_transform_gen = build_augmentation
+"""
+Alias for backward-compatibility.
+"""
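
The module-level alias added at the end is what keeps old import sites working: build_transform_gen is now just another name for build_augmentation. A quick sketch of both spellings, assuming AdelaiDet's get_cfg (whose defaults include the INPUT.HFLIP_TRAIN flag read above):

from adet.config import get_cfg
from adet.data.detection_utils import build_augmentation, build_transform_gen

cfg = get_cfg()

augs = build_augmentation(cfg, is_train=True)           # new name
augs_compat = build_transform_gen(cfg, is_train=True)   # old name, same function

# Both names are bound to the same function object, so behavior is identical.
assert build_transform_gen is build_augmentation
print([type(a).__name__ for a in augs])  # e.g. ['ResizeShortestEdge', 'RandomFlip']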