Merge pull request #131 from aim-uofa/crop_augmentation_fix
Showing 3 changed files with 226 additions and 148 deletions.
@@ -0,0 +1,123 @@
import numpy as np
from fvcore.transforms import transform as T

from detectron2.data.transforms import RandomCrop, StandardAugInput
from detectron2.structures import BoxMode


class InstanceAugInput(StandardAugInput):
    """
    Keep the old behavior of instance-aware augmentation
    """

    def __init__(self, *args, **kwargs):
        instances = kwargs.pop("instances", None)
        super().__init__(*args, **kwargs)
        if instances is not None:
            self.instances = instances


def gen_crop_transform_with_instance(crop_size, image_size, instances, crop_box=True):
    """
    Generate a CropTransform so that the cropping region contains
    the center of a randomly chosen instance.
    Args:
        crop_size (tuple): h, w in pixels
        image_size (tuple): h, w
        instances (list[dict]): annotation dicts of the instances in the image,
            in Detectron2's dataset format; one is picked at random.
        crop_box (bool): if False, the crop box is extended afterwards so that
            no instance box is cut by the crop.
    """
    instance = np.random.choice(instances)
    crop_size = np.asarray(crop_size, dtype=np.int32)
    bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
    center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
    assert (
        image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
    ), "The annotation bounding box is outside of the image!"
    assert (
        image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
    ), "Crop size is larger than image size!"

    # valid range for the crop's top-left corner: it must keep the instance
    # center inside the crop and the crop inside the image
    min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)
    max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)
    max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))

    y0 = np.random.randint(min_yx[0], max_yx[0] + 1)
    x0 = np.random.randint(min_yx[1], max_yx[1] + 1)

    # if some instance is cut by the crop, extend the crop box
    if not crop_box:
        num_modifications = 0
        modified = True

        # convert crop_size to float
        crop_size = crop_size.astype(np.float32)
        while modified:
            modified, x0, y0, crop_size = adjust_crop(x0, y0, crop_size, instances)
            num_modifications += 1
            if num_modifications > 100:
                raise ValueError(
                    "Cannot finish cropping adjustment within 100 tries (#instances {}).".format(
                        len(instances)
                    )
                )
                return T.CropTransform(0, 0, image_size[1], image_size[0])

    return T.CropTransform(*map(int, (x0, y0, crop_size[1], crop_size[0])))


def adjust_crop(x0, y0, crop_size, instances, eps=1e-3):
    """
    Extend the crop window (by growing crop_size and moving its corner) so that
    any instance box straddling one of its edges becomes fully contained.
    Returns whether the window was modified, plus the updated corner and size.
    """
    modified = False

    x1 = x0 + crop_size[1]
    y1 = y0 + crop_size[0]

    for instance in instances:
        bbox = BoxMode.convert(
            instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS
        )

        if bbox[0] < x0 - eps and bbox[2] > x0 + eps:
            crop_size[1] += x0 - bbox[0]
            x0 = bbox[0]
            modified = True

        if bbox[0] < x1 - eps and bbox[2] > x1 + eps:
            crop_size[1] += bbox[2] - x1
            x1 = bbox[2]
            modified = True

        if bbox[1] < y0 - eps and bbox[3] > y0 + eps:
            crop_size[0] += y0 - bbox[1]
            y0 = bbox[1]
            modified = True

        if bbox[1] < y1 - eps and bbox[3] > y1 + eps:
            crop_size[0] += bbox[3] - y1
            y1 = bbox[3]
            modified = True

    return modified, x0, y0, crop_size


class RandomCropWithInstance(RandomCrop):
    """
    Instance-aware cropping.
    """

    def __init__(self, crop_type, crop_size, crop_instance=True):
        """
        Args:
            crop_instance (bool): if False, extend cropping boxes to avoid cropping instances
        """
        super().__init__(crop_type, crop_size)
        self.crop_instance = crop_instance
        self.input_args = ("image", "instances")

    def get_transform(self, img, instances):
        image_size = img.shape[:2]
        crop_size = self.get_crop_size(image_size)
        return gen_crop_transform_with_instance(
            crop_size, image_size, instances, crop_box=self.crop_instance
        )