Replace imgaug with albumentations (#1623)
What's the worst that could happen?

* Initial commit

* Fix augmentation

* Update more deps requirements

* Use pip for installing albumentations and avoid reinstalling OpenCV

* Update other conda envs
talmo authored Mar 17, 2024
1 parent 8ab323e commit eb14764
Showing 12 changed files with 61 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .conda/bld.bat
@@ -7,7 +7,7 @@ set PIP_IGNORE_INSTALLED=False

@REM Install the pip dependencies. Note: Using urls to wheels might be better:
@REM https://docs.conda.io/projects/conda-build/en/stable/user-guide/wheel-files.html)
pip install --no-cache-dir -r .\requirements.txt
pip install --no-cache-dir -r .\requirements.txt --no-binary qudida,albumentations

@REM Install sleap itself. This does not install the requirements, but will list which
@REM requirements are missing (see "install_requires") when user attempts to install.
2 changes: 1 addition & 1 deletion .conda/build.sh
@@ -7,7 +7,7 @@ export PIP_IGNORE_INSTALLED=False

# Install the pip dependencies. Note: Using urls to wheels might be better:
# https://docs.conda.io/projects/conda-build/en/stable/user-guide/wheel-files.html)
pip install --no-cache-dir -r ./requirements.txt
pip install --no-cache-dir -r ./requirements.txt --no-binary qudida,albumentations


# Install sleap itself. This does not install the requirements, but will list which
2 changes: 0 additions & 2 deletions .conda/meta.yaml
@@ -33,7 +33,6 @@ requirements:
- conda-forge::attrs ==21.4.0
- conda-forge::cattrs ==1.1.1
- conda-forge::h5py ==3.1 # [not win]
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::networkx
@@ -61,7 +60,6 @@ requirements:
- conda-forge::cudnn=8.2.1
- nvidia::cuda-nvcc=11.3
- conda-forge::h5py ==3.1 # [not win]
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::networkx
2 changes: 1 addition & 1 deletion .conda_mac/build.sh
@@ -7,6 +7,6 @@ export PIP_NO_INDEX=False
export PIP_NO_DEPENDENCIES=False
export PIP_IGNORE_INSTALLED=False

pip install --no-cache-dir -r requirements.txt
pip install --no-cache-dir -r requirements.txt --no-binary qudida,albumentations

python setup.py install --single-version-externally-managed --record=record.txt
2 changes: 0 additions & 2 deletions .conda_mac/meta.yaml
@@ -34,7 +34,6 @@ requirements:
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::h5py
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos
@@ -61,7 +60,6 @@ requirements:
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::h5py
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos
3 changes: 1 addition & 2 deletions environment.yml
@@ -12,7 +12,6 @@ dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::networkx
@@ -46,4 +45,4 @@ dependencies:

- pip:
- "--editable=.[conda_dev]"

- "--no-binary qudida,albumentations"
2 changes: 1 addition & 1 deletion environment_mac.yml
@@ -11,7 +11,6 @@ dependencies:
- conda-forge::attrs >=21.2.0
- conda-forge::cattrs ==1.1.1
- conda-forge::h5py
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::keras <2.10.0,>=2.9.0rc0 # Required by tensorflow-macos
@@ -38,3 +37,4 @@ dependencies:
- conda-forge::tensorflow-hub
- pip:
- "--editable=.[conda_dev]"
- "--no-binary qudida,albumentations"
2 changes: 1 addition & 1 deletion environment_no_cuda.yml
@@ -13,7 +13,6 @@ dependencies:
# Packages SLEAP uses directly
- conda-forge::attrs >=21.2.0 #,<=21.4.0
- conda-forge::cattrs ==1.1.1
- conda-forge::imgaug ==0.4.0
- conda-forge::jsmin
- conda-forge::jsonpickle ==1.2
- conda-forge::networkx
@@ -41,3 +40,4 @@ dependencies:

- pip:
- "--editable=.[conda_dev]"
- "--no-binary qudida,albumentations"
2 changes: 2 additions & 0 deletions requirements.txt
@@ -12,3 +12,5 @@ tensorflow-metal==0.5.0; sys_platform == 'darwin' and platform_machine == 'arm64
# Conda installing results in https://github.com/h5py/h5py/issues/2037
h5py<3.2; sys_platform == 'win32' # Newer versions result in error above, linking issue in Linux
pynwb>=2.3.3 # 2.0.0 required by ndx-pose, 2.3.3 fixes importlib-metadata incompatibility

albumentations
88 changes: 42 additions & 46 deletions sleap/nn/data/augmentation.py
@@ -1,19 +1,11 @@
"""Transformers for applying data augmentation."""

# Monkey patch for: https://github.com/aleju/imgaug/issues/537
# TODO: Fix when PyPI/conda packages are available for version fencing.
import numpy

if hasattr(numpy.random, "_bit_generator"):
numpy.random.bit_generator = numpy.random._bit_generator

import sleap
import numpy as np
import tensorflow as tf
import attr
from typing import List, Text, Optional
import imgaug as ia
import imgaug.augmenters as iaa
import albumentations as A
from sleap.nn.config import AugmentationConfig
from sleap.nn.data.instance_cropping import crop_bboxes

@@ -111,23 +103,23 @@ def flip_instances_ud(


@attr.s(auto_attribs=True)
class ImgaugAugmenter:
"""Data transformer based on the `imgaug` library.
class AlbumentationsAugmenter:
"""Data transformer based on the `albumentations` library.
This class can generate a `tf.data.Dataset` from an existing one that generates
image and instance data. Elements of the output dataset will have a set of
augmentation transformations applied.
Attributes:
augmenter: An instance of `imgaug.augmenters.Sequential` that will be applied to
augmenter: An instance of `albumentations.Compose` that will be applied to
each element of the input dataset.
image_key: Name of the example key where the image is stored. Defaults to
"image".
instances_key: Name of the example key where the instance points are stored.
Defaults to "instances".
"""

augmenter: iaa.Sequential
augmenter: A.Compose
image_key: str = "image"
instances_key: str = "instances"

@@ -137,7 +129,7 @@ def from_config(
config: AugmentationConfig,
image_key: Text = "image",
instances_key: Text = "instances",
) -> "ImgaugAugmenter":
) -> "AlbumentationsAugmenter":
"""Create an augmenter from a set of configuration parameters.
Args:
@@ -148,52 +140,63 @@ def from_config(
Defaults to "instances".
Returns:
An instance of `ImgaugAugmenter` with the specified augmentation
An instance of `AlbumentationsAugmenter` with the specified augmentation
configuration.
"""
aug_stack = []
if config.rotate:
aug_stack.append(
iaa.Affine(
rotate=(config.rotation_min_angle, config.rotation_max_angle)
A.Rotate(
limit=(config.rotation_min_angle, config.rotation_max_angle), p=1.0
)
)
if config.translate:
aug_stack.append(
iaa.Affine(
A.Affine(
translate_px={
"x": (config.translate_min, config.translate_max),
"y": (config.translate_min, config.translate_max),
}
},
p=1.0,
)
)
if config.scale:
aug_stack.append(iaa.Affine(scale=(config.scale_min, config.scale_max)))
if config.uniform_noise:
aug_stack.append(
iaa.AddElementwise(
value=(config.uniform_noise_min_val, config.uniform_noise_max_val)
)
A.Affine(scale=(config.scale_min, config.scale_max), p=1.0)
)
if config.uniform_noise:

def uniform_noise(image, **kwargs):
return image + np.random.uniform(
config.uniform_noise_min_val, config.uniform_noise_max_val
)

aug_stack.append(A.Lambda(image=uniform_noise))
if config.gaussian_noise:
aug_stack.append(
iaa.AdditiveGaussianNoise(
loc=config.gaussian_noise_mean, scale=config.gaussian_noise_stddev
A.GaussNoise(
mean=config.gaussian_noise_mean,
var_limit=config.gaussian_noise_stddev,
)
)
if config.contrast:
aug_stack.append(
iaa.GammaContrast(
gamma=(config.contrast_min_gamma, config.contrast_max_gamma)
A.RandomGamma(
gamma_limit=(config.contrast_min_gamma, config.contrast_max_gamma),
p=1.0,
)
)
if config.brightness:
aug_stack.append(
iaa.Add(value=(config.brightness_min_val, config.brightness_max_val))
A.RandomBrightness(
limit=(config.brightness_min_val, config.brightness_max_val), p=1.0
)
)

return cls(
augmenter=iaa.Sequential(aug_stack),
augmenter=A.Compose(
aug_stack, keypoint_params=A.KeypointParams(format="xy")
),
image_key=image_key,
instances_key=instances_key,
)
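Albumentations has no direct counterpart to imgaug's AddElementwise, which is why the uniform-noise option above is wrapped in A.Lambda. A minimal standalone sketch of that pattern (the noise bounds and per-pixel draw here are illustrative choices, not values taken from the config):

import numpy as np
import albumentations as A

def add_uniform_noise(image, **kwargs):
    # Offset each pixel by noise drawn from U(-10, 10); bounds are illustrative.
    return image + np.random.uniform(-10.0, 10.0, size=image.shape)

noise_aug = A.Compose([A.Lambda(image=add_uniform_noise, p=1.0)])
noisy = noise_aug(image=np.zeros((64, 64, 1), dtype="float32"))["image"]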
@@ -226,22 +229,16 @@ def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
# Define augmentation function to map over each sample.
def py_augment(image, instances):
"""Local processing function that will not be autographed."""
# Ensure that the transformations applied to all data within this
# example are kept consistent.
aug_det = self.augmenter.to_deterministic()
# Convert to numpy arrays.
img = image.numpy()
kps = instances.numpy()
original_shape = kps.shape
kps = kps.reshape(-1, 2)

# Augment the image.
aug_img = aug_det.augment_image(image.numpy())

# This will get converted to a rank 3 tensor (n_instances, n_nodes, 2).
aug_instances = np.full_like(instances, np.nan)

# Augment each set of points for each instance.
for i, instance in enumerate(instances):
kps = ia.KeypointsOnImage.from_xy_array(
instance.numpy(), tuple(image.shape)
)
aug_instances[i] = aug_det.augment_keypoints(kps).to_xy_array()
# Augment.
augmented = self.augmenter(image=img, keypoints=kps)
aug_img = augmented["image"]
aug_instances = np.array(augmented["keypoints"]).reshape(original_shape)

return aug_img, aug_instances

@@ -258,7 +255,6 @@ def augment(frame_data):
return frame_data

# Apply the augmentation to each element.
# Note: We map sequentially since imgaug gets slower with tf.data parallelism.
output_ds = input_ds.map(augment)

return output_ds
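For reference, the keypoint round trip that py_augment now performs looks like this when albumentations is called directly (a standalone sketch with made-up shapes and coordinates; the KeypointParams(format="xy") setting matches the Compose built in from_config above):

import numpy as np
import albumentations as A

augmenter = A.Compose(
    [A.Rotate(limit=(-15, 15), p=1.0)],
    keypoint_params=A.KeypointParams(format="xy"),
)

image = np.zeros((128, 128, 1), dtype="uint8")
instances = np.array(
    [[[40.0, 40.0], [60.0, 50.0]], [[70.0, 80.0], [90.0, 90.0]]]
)  # (n_instances, n_nodes, 2)

# Flatten to a list of (x, y) points, augment, then restore the original shape,
# mirroring py_augment. These points stay inside the frame; by default
# albumentations drops out-of-frame keypoints unless remove_invisible=False is set.
kps = instances.reshape(-1, 2)
augmented = augmenter(image=image, keypoints=kps)
aug_image = augmented["image"]
aug_instances = np.array(augmented["keypoints"]).reshape(instances.shape)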
16 changes: 8 additions & 8 deletions sleap/nn/data/pipelines.py
@@ -18,7 +18,7 @@
from sleap.nn.data.providers import LabelsReader, VideoReader
from sleap.nn.data.augmentation import (
AugmentationConfig,
ImgaugAugmenter,
AlbumentationsAugmenter,
RandomCropper,
RandomFlipper,
)
@@ -68,7 +68,7 @@

PROVIDERS = (LabelsReader, VideoReader)
TRANSFORMERS = (
ImgaugAugmenter,
AlbumentationsAugmenter,
RandomCropper,
Normalizer,
Resizer,
@@ -406,7 +406,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
self.data_config.labels.skeletons[0],
horizontal=self.optimization_config.augmentation_config.flip_horizontal,
)
pipeline += ImgaugAugmenter.from_config(
pipeline += AlbumentationsAugmenter.from_config(
self.optimization_config.augmentation_config
)
if self.optimization_config.augmentation_config.random_crop:
@@ -550,7 +550,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
self.data_config.labels.skeletons[0],
horizontal=self.optimization_config.augmentation_config.flip_horizontal,
)
pipeline += ImgaugAugmenter.from_config(
pipeline += AlbumentationsAugmenter.from_config(
self.optimization_config.augmentation_config
)
if self.optimization_config.augmentation_config.random_crop:
@@ -713,7 +713,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
self.data_config.labels.skeletons[0],
horizontal=self.optimization_config.augmentation_config.flip_horizontal,
)
pipeline += ImgaugAugmenter.from_config(
pipeline += AlbumentationsAugmenter.from_config(
self.optimization_config.augmentation_config
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
@@ -863,7 +863,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
self.data_config.labels.skeletons[0],
horizontal=aug_config.flip_horizontal,
)
pipeline += ImgaugAugmenter.from_config(aug_config)
pipeline += AlbumentationsAugmenter.from_config(aug_config)
if aug_config.random_crop:
pipeline += RandomCropper(
crop_height=aug_config.random_crop_height,
@@ -1028,7 +1028,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
horizontal=aug_config.flip_horizontal,
)

pipeline += ImgaugAugmenter.from_config(aug_config)
pipeline += AlbumentationsAugmenter.from_config(aug_config)
if aug_config.random_crop:
pipeline += RandomCropper(
crop_height=aug_config.random_crop_height,
@@ -1186,7 +1186,7 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
config=self.data_config.preprocessing,
provider=data_provider,
)
pipeline += ImgaugAugmenter.from_config(
pipeline += AlbumentationsAugmenter.from_config(
self.optimization_config.augmentation_config
)
pipeline += Normalizer.from_config(self.data_config.preprocessing)
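All of the pipeline builders above now attach the augmenter through the same call; a hedged sketch of that usage with illustrative config values:

from sleap.nn.config import AugmentationConfig
from sleap.nn.data.augmentation import AlbumentationsAugmenter

aug_config = AugmentationConfig(
    rotate=True, rotation_min_angle=-15, rotation_max_angle=15
)
augmenter = AlbumentationsAugmenter.from_config(aug_config)
# augmenter.augmenter is an albumentations.Compose; pipelines apply it via
# augmenter.transform_dataset(ds) on a tf.data.Dataset of images and instances.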
6 changes: 3 additions & 3 deletions tests/nn/data/test_augmentation.py
@@ -14,9 +14,9 @@ def test_augmentation(min_labels):
ds = labels_reader.make_dataset()
example_preaug = next(iter(ds))

augmenter = augmentation.ImgaugAugmenter.from_config(
augmenter = augmentation.AlbumentationsAugmenter.from_config(
augmentation.AugmentationConfig(
rotate=True, rotation_min_angle=-90, rotation_max_angle=-90
rotate=True, rotation_min_angle=90, rotation_max_angle=90
)
)
ds = augmenter.transform_dataset(ds)
@@ -52,7 +52,7 @@ def test_augmentation_with_no_instances(min_labels):
)

p = min_labels.to_pipeline(user_labeled_only=False)
p += augmentation.ImgaugAugmenter.from_config(
p += augmentation.AlbumentationsAugmenter.from_config(
augmentation.AugmentationConfig(rotate=True)
)
exs = p.run()
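The fixed-angle test flips from -90 to +90, which suggests albumentations and imgaug use opposite sign conventions for rotation. A quick standalone way to see where a fixed rotation sends a point (image size and coordinates are arbitrary):

import numpy as np
import albumentations as A

check = A.Compose(
    [A.Rotate(limit=(90, 90), p=1.0)],
    keypoint_params=A.KeypointParams(format="xy"),
)
image = np.zeros((100, 100, 1), dtype="uint8")
out = check(image=image, keypoints=[(20.0, 50.0)])
print(out["keypoints"])  # inspect how a fixed +90 degree rotation moves the point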
