diff --git a/data_loader/modules/iaa_augment.py b/data_loader/modules/iaa_augment.py index bd2588e..902862f 100644 --- a/data_loader/modules/iaa_augment.py +++ b/data_loader/modules/iaa_augment.py @@ -48,17 +48,12 @@ def __call__(self, data): def may_augment_annotation(self, aug, data, shape): if aug is None: return data - - line_polys = [] + + all_poly = [] for poly in data['text_polys']: - new_poly = self.may_augment_poly(aug, shape, poly) - line_polys.append(new_poly) - data['text_polys'] = np.array(line_polys) + for p in poly: + all_poly.append(imgaug.Keypoint(p[0], p[1])) + keypoints = aug.augment_keypoints([imgaug.KeypointsOnImage(all_poly, shape=shape)])[0].keypoints + final_poly = np.array([(p.x, p.y) for p in keypoints]).reshape([-1, 4, 2]) + data['text_polys'] =final_poly return data - - def may_augment_poly(self, aug, img_shape, poly): - keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] - keypoints = aug.augment_keypoints( - [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints - poly = [(p.x, p.y) for p in keypoints] - return poly