Add albumentations (open-mmlab#45)
* Add Albu transform

* pre-commit

* Create optional.txt

* Update requirements.txt

* Update transforms.py
daavoo authored Sep 22, 2020
1 parent 8d3acce commit 99115fd
Showing 8 changed files with 192 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/getting_started.md
@@ -64,7 +64,7 @@ Optional arguments:

Examples:

Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.
Test ResNet-50 on ImageNet validation and evaluate the top-1 and top-5 accuracy.

```shell
...
```
2 changes: 1 addition & 1 deletion docs/tutorials/finetune.md
@@ -49,7 +49,7 @@ img_norm_cfg = dict(
train_pipeline = [
dict(type='RandomCrop', size=32, padding=4),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Resize', size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
40 changes: 20 additions & 20 deletions docs/tutorials/new_modules.md
@@ -27,9 +27,9 @@
```python
from .resnet import ResNet

class ResNet_CIFAR(ResNet):

"""ResNet backbone for CIFAR.
short description of the backbone
Args:
depth(int): Network depth, from {18, 34, 50, 101, 152}.
...
@@ -45,7 +45,7 @@ class ResNet_CIFAR(ResNet):

def init_weights(self, pretrained=None):
pass # override ResNet init_weights if necessary

def train(self, mode=True):
pass # override ResNet train if necessary
```
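Once registered, the new backbone can be referenced from a config file by its type name. A minimal sketch (the `depth` value and the surrounding fields are illustrative, not part of this diff):

```python
model = dict(
    # neck and head configured as in the following sections
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18))
```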
@@ -77,7 +77,7 @@ To add a new neck, we mainly implement the `forward` function, which applies some

```python
import torch.nn as nn

from ..builder import NECKS

@NECKS.register_module()
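class GlobalAveragePooling(nn.Module):
    """Illustrative sketch only: the diff truncates the class body here.

    Assumes a global-average-pooling neck, matching the GlobalAveragePooling
    neck referenced later in this tutorial; not the committed code.
    """

    def __init__(self):
        super(GlobalAveragePooling, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, inputs):
        # Pool (N, C, H, W) features to (N, C, 1, 1), then flatten to (N, C).
        outs = self.gap(inputs)
        return outs.view(inputs.size(0), -1)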
```

@@ -117,11 +117,11 @@ To implement a new head, basically we need to implement `forward_train`, which t
```python
from ..builder import HEADS
from .cls_head import ClsHead


@HEADS.register_module()
class LinearClsHead(ClsHead):

def __init__(self,
num_classes,
in_channels,
@@ -130,24 +130,24 @@
super(LinearClsHead, self).__init__(loss=loss, topk=topk)
self.in_channels = in_channels
self.num_classes = num_classes

if self.num_classes <= 0:
raise ValueError(
f'num_classes={num_classes} must be a positive integer')

self._init_layers()

def _init_layers(self):
self.fc = nn.Linear(self.in_channels, self.num_classes)

def init_weights(self):
normal_init(self.fc, mean=0, std=0.01, bias=0)

def forward_train(self, x, gt_label):
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label)
return losses

```


@@ -178,37 +178,37 @@ Together with the added GlobalAveragePooling neck, an entire config for a model
```python
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))

```

### Add new loss

To add a new loss function, we mainly implement the `forward` function in the loss module.
In addition, it is helpful to leverage the decorator `weighted_loss` to weight the loss for each element.
Assuming that we want to mimic a probabilistic distribution generated from another classification model, we implement an L1Loss to fulfill the purpose as below.

1. Create a new file in `mmcls/models/losses/l1_loss.py`.
```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def l1_loss(pred, target):
assert pred.size() == target.size() and target.numel() > 0
loss = torch.abs(pred - target)
return loss

@LOSSES.register_module()
class L1Loss(nn.Module):

def __init__(self, reduction='mean', loss_weight=1.0):
super(L1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight

def forward(self,
pred,
target,
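                weight=None,
                avg_factor=None,
                reduction_override=None):
        # NOTE: the diff truncates the signature above; the remaining
        # arguments and this body are an illustrative sketch following the
        # usual ``weighted_loss`` convention, not the committed text.
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        return self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)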
```
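A head config can then reference the new loss by its registered name, mirroring the `CrossEntropyLoss` entry in the config above (a sketch; the `num_classes` and `in_channels` values are placeholders):

```python
head=dict(
    type='LinearClsHead',
    num_classes=1000,
    in_channels=2048,
    loss=dict(type='L1Loss', loss_weight=1.0),
    topk=(1, 5))
```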
8 changes: 4 additions & 4 deletions mmcls/datasets/cifar.py
@@ -18,8 +18,8 @@ class CIFAR10(BaseDataset):
"""

base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
filename = 'cifar-10-python.tar.gz'
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
@@ -110,8 +110,8 @@ class CIFAR100(CIFAR10):
"""

base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
filename = 'cifar-100-python.tar.gz'
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
144 changes: 140 additions & 4 deletions mmcls/datasets/pipelines/transforms.py
@@ -1,3 +1,4 @@
import inspect
import math
import random

@@ -6,6 +7,13 @@

from ..builder import PIPELINES

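# albumentations is an optional dependency: the import may fail silently
# here, and the Albu transform below raises a RuntimeError only when it is
# constructed without the package installed.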
try:
import albumentations
from albumentations import Compose
except ImportError:
albumentations = None
Compose = None


@PIPELINES.register_module()
class RandomCrop(object):
@@ -155,8 +163,8 @@ def __init__(self,
else:
self.size = (size, size)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
raise ValueError("range should be of kind (min, max). "
f"But received {scale}")
raise ValueError('range should be of kind (min, max). '
f'But received {scale}')
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize.'
' Supported backends are "cv2", "pillow"')
@@ -363,8 +371,8 @@ def __init__(self, size, interpolation='bilinear', backend='cv2'):
assert size[0] > 0 and (size[1] > 0 or size[1] == -1)
if size[1] == -1:
self.resize_w_short_side = True
assert interpolation in ("nearest", "bilinear", "bicubic", "area",
"lanczos")
assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
'lanczos')
if backend not in ['cv2', 'pillow']:
raise ValueError(f'backend: {backend} is not supported for resize.'
' Supported backends are "cv2", "pillow"')
@@ -486,3 +494,131 @@ def __repr__(self):
repr_str += f'std={list(self.std)}, '
repr_str += f'to_rgb={self.to_rgb})'
return repr_str


@PIPELINES.register_module()
class Albu(object):
"""Albumentation augmentation.
Adds custom transformations from the Albumentations library.
Please visit `https://albumentations.readthedocs.io`
to get more information.
An example of ``transforms`` is as follows:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
keymap (dict): Contains {'input key':'albumentation-style key'}
"""

def __init__(self, transforms, keymap=None, update_pad_shape=False):
if Compose is None:
raise RuntimeError('albumentations is not installed')

self.transforms = transforms
self.filter_lost_elements = False
self.update_pad_shape = update_pad_shape

self.aug = Compose([self.albu_builder(t) for t in self.transforms])

if not keymap:
self.keymap_to_albu = {
'img': 'image',
}
else:
self.keymap_to_albu = keymap
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

def albu_builder(self, cfg):
"""Import a module from albumentations.
It inherits some of :func:`build_from_cfg` logic.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
Returns:
obj: The constructed object.
"""

assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()

obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if albumentations is None:
raise RuntimeError('albumentations is not installed')
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')

if 'transforms' in args:
args['transforms'] = [
self.albu_builder(transform)
for transform in args['transforms']
]

return obj_cls(**args)

@staticmethod
def mapper(d, keymap):
"""Dictionary mapper. Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""

updated_dict = {}
for k, v in d.items():
    new_k = keymap.get(k, k)
    updated_dict[new_k] = v
return updated_dict

def __call__(self, results):
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)

results = self.aug(**results)

if 'gt_labels' in results:
if isinstance(results['gt_labels'], list):
results['gt_labels'] = np.array(results['gt_labels'])
results['gt_labels'] = results['gt_labels'].astype(np.int64)

# back to the original format
results = self.mapper(results, self.keymap_back)

# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape

return results

def __repr__(self):
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
return repr_str
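For reference, the new transform drops into a pipeline like any other `PIPELINES` entry. A minimal sketch assembled from the docstring example above (`img_norm_cfg` stands in for a normalization config, as in the fine-tuning tutorial):

```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Albu',
        transforms=[
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                p=0.5),
            dict(type='ChannelShuffle', p=0.1),
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
]
```

The default `keymap` translates the mmcls `img` key to the `image` key that albumentations expects, and translates it back after augmentation.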
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
-r requirements/runtime.txt
-r requirements/optional.txt
-r requirements/tests.txt
1 change: 1 addition & 0 deletions requirements/optional.txt
@@ -0,0 +1 @@
albumentations>=0.3.2
24 changes: 24 additions & 0 deletions tests/test_pipelines/test_transform.py
@@ -780,3 +780,27 @@ def reset_results(results, original_img):
flipped_img = np.array(flipped_img)
assert np.equal(results['img'], results['img2']).all()
assert np.equal(results['img'], flipped_img).all()


def test_albu_transform():
results = dict(
img_prefix=osp.join(osp.dirname(__file__), '../data'),
img_info=dict(filename='color.jpg'))

# Define simple pipeline
load = dict(type='LoadImageFromFile')
load = build_from_cfg(load, PIPELINES)

albu_transform = dict(
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
albu_transform = build_from_cfg(albu_transform, PIPELINES)

normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
normalize = build_from_cfg(normalize, PIPELINES)

# Execute transforms
results = load(results)
results = albu_transform(results)
results = normalize(results)

assert results['img'].dtype == np.float32
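A custom `keymap` can be exercised the same way; a short sketch (this mapping simply restates the default, so behavior is unchanged):

```python
albu_keymap_transform = build_from_cfg(
    dict(
        type='Albu',
        transforms=[dict(type='ChannelShuffle', p=1)],
        keymap={'img': 'image'}),
    PIPELINES)
```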
