Skip to content

Commit

Permalink
crop resize
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanklut committed Jul 12, 2024
1 parent 2e46321 commit c154ff6
Showing 1 changed file with 88 additions and 2 deletions.
90 changes: 88 additions & 2 deletions data/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,93 @@ def get_transform(self, image, sem_seg) -> T.Transform:
raise ValueError(f"Image type {type(image)} not supported")


class CropResize(T.Augmentation):
    """
    Crop a region out of the image and resize that region to ``crop_size``.

    The size the full image would get from the configured resize augmentation
    is only used to derive a scale factor. The crop is taken from the original
    image with a size scaled by that factor and is then resized to
    ``crop_size``, so the full image itself is never resized.
    """

    def __init__(
        self,
        crop_type,
        crop_size,
        resize_mode,
        scale=None,
        target_dpi=None,
        min_size=None,
        max_size=None,
        ignore_value=255,
    ):
        """
        Args:
            crop_type: crop type forwarded to ``RandomCrop``.
            crop_size: (height, width) of the output crop. It is multiplied by the
                resize scale, so it is presumably expected in absolute pixels
                (NOTE(review): confirm against the crop types ``RandomCrop`` supports).
            resize_mode: one of "shortest_edge", "longest_edge" or "scaling".
            scale: forwarded to ``ResizeScaling``; required for "scaling" mode.
            target_dpi: stored on the instance; not used by the methods of this class.
            min_size: forwarded to the edge-based resize; required for
                "shortest_edge" and "longest_edge" modes.
            max_size: forwarded to the selected resize; required for every mode.
            ignore_value: stored on the instance; not used by the methods of this class.

        Raises:
            ValueError: if a parameter required by ``resize_mode`` is missing, or if
                ``resize_mode`` is not one of the supported modes.
        """
        super().__init__()

        if resize_mode == "shortest_edge":
            if min_size is None or max_size is None:
                raise ValueError("min_size and max_size must be provided for shortest_edge mode")
            self.resize = ResizeShortestEdge(min_size, max_size, "choice")
        elif resize_mode == "longest_edge":
            if min_size is None or max_size is None:
                raise ValueError("min_size and max_size must be provided for longest_edge mode")
            self.resize = ResizeLongestEdge(min_size, max_size, "choice")
        elif resize_mode == "scaling":
            if scale is None or max_size is None:
                raise ValueError("scale and max_size must be provided for scaling mode")
            self.resize = ResizeScaling(scale, max_size)
        else:
            # Fail at construction time; otherwise self.resize would be missing and the
            # error would only show up as an AttributeError inside get_transform
            raise ValueError(
                f"Unknown resize mode {resize_mode!r}, expected 'shortest_edge', 'longest_edge' or 'scaling'"
            )

        self.target_dpi = target_dpi

        self.crop_size = crop_size
        self.crop_type = crop_type

        self.ignore_value = ignore_value

    def _scaled_crop_size(self, height: int, width: int, dpi: int) -> tuple[int, int]:
        """
        Compute the crop size in original-image pixels that maps onto ``crop_size`` after resizing.

        Args:
            height: height of the original image.
            width: width of the original image.
            dpi: dpi of the image, forwarded to the resize augmentation.

        Returns:
            (height, width) of the region to crop from the original image.
        """
        new_height, new_width = self.resize.get_output_shape(height, width, dpi)
        # Ratio original / resized: a crop of this scaled size in the original image
        # covers the same content as a crop of crop_size in the resized image
        scale_y = height / new_height
        scale_x = width / new_width
        return (int(self.crop_size[0] * scale_y), int(self.crop_size[1] * scale_x))

    def numpy_transform(self, image: np.ndarray, dpi: int) -> T.Transform:
        """
        Build the crop + resize transform for a numpy image.

        Args:
            image: image whose first two dimensions are (height, width).
            dpi: dpi of the image, forwarded to the resize augmentation.

        Returns:
            A ``T.TransformList`` that first crops and then resizes to ``crop_size``.
        """
        height, width = image.shape[:2]
        scaled_crop_size = self._scaled_crop_size(height, width, dpi)

        # The crop must use the scaled size so that its output matches the input size
        # declared on the resize below; cropping with crop_size only matches at scale 1
        # NOTE(review): assumes RandomCrop yields exactly the requested size — confirm
        # for images smaller than the scaled crop
        return T.TransformList(
            [
                RandomCrop(self.crop_type, scaled_crop_size).numpy_transform(image),
                NT.ResizeTransform(
                    height=scaled_crop_size[0],
                    width=scaled_crop_size[1],
                    new_height=self.crop_size[0],
                    new_width=self.crop_size[1],
                ),
            ]
        )

    def torch_transform(self, image: torch.Tensor, dpi: int) -> T.Transform:
        """
        Build the crop + resize transform for a torch image.

        Args:
            image: image whose last two dimensions are (height, width).
            dpi: dpi of the image, forwarded to the resize augmentation.

        Returns:
            A ``T.TransformList`` that first crops and then resizes to ``crop_size``.
        """
        height, width = image.shape[-2:]
        scaled_crop_size = self._scaled_crop_size(height, width, dpi)

        # See numpy_transform: the crop uses the scaled size so both stages agree
        return T.TransformList(
            [
                RandomCrop(self.crop_type, scaled_crop_size).torch_transform(image),
                TT.ResizeTransform(
                    height=scaled_crop_size[0],
                    width=scaled_crop_size[1],
                    new_height=self.crop_size[0],
                    new_width=self.crop_size[1],
                ),
            ]
        )

    def get_transform(self, image, dpi):
        """
        Dispatch to the numpy or torch implementation based on the image type.

        Args:
            image: a ``np.ndarray`` or ``torch.Tensor`` image.
            dpi: dpi of the image, forwarded to the resize augmentation.

        Returns:
            The crop + resize transform for the image.

        Raises:
            ValueError: if the image is neither a numpy array nor a torch tensor.
        """
        if isinstance(image, np.ndarray):
            return self.numpy_transform(image, dpi)
        elif isinstance(image, torch.Tensor):
            return self.torch_transform(image, dpi)
        else:
            raise ValueError(f"Image type {type(image)} not supported")


def build_augmentation(cfg: CfgNode, mode: str = "train") -> list[T.Augmentation]:
"""
Function to generate all the augmentations used in the inference and training process
Expand Down Expand Up @@ -1558,10 +1645,9 @@ def build_augmentation(cfg: CfgNode, mode: str = "train") -> list[T.Augmentation
# Crop
if cfg.INPUT.CROP.ENABLED:
augmentation.append(
RandomCrop_CategoryAreaConstraint(
RandomCrop(
crop_type=cfg.INPUT.CROP.TYPE,
crop_size=cfg.INPUT.CROP.SIZE,
single_category_max_area=cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
)
)

Expand Down

0 comments on commit c154ff6

Please sign in to comment.