Commit 7d2e623: first commit
Tokozume committed Dec 12, 2017

Showing 10 changed files with 534 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

### PyCharm ###
\.idea/
67 changes: 67 additions & 0 deletions README.md
@@ -0,0 +1,67 @@
BC learning for images
=========================

Implementation of [Between-class Learning for Image Classification](https://arxiv.org/abs/1711.10284) by Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada.

#### Between-class (BC) learning:
- We generate between-class examples by mixing two training examples belonging to different classes with a random ratio.
- We then feed the mixed data to the model and train it to output the mixing ratio (a minimal sketch of this step follows this list).
- Original paper: [Learning from Between-class Examples for Deep Sound Recognition](https://arxiv.org/abs/1711.10282) by us ([github](https://github.com/mil-tokyo/bc_learning_sound))
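A minimal sketch of the BC mixing step; it mirrors `get_example` in `dataset.py` below, and the function name `bc_mix` is introduced here only for illustration:

    import random
    import numpy as np

    def bc_mix(image1, label1, image2, label2, n_classes):
        # Mix two preprocessed images of different classes with a random ratio r
        r = random.random()
        image = (image1 * r + image2 * (1 - r)).astype(np.float32)
        # The soft target mixes the one-hot labels with the same ratio
        eye = np.eye(n_classes)
        label = (eye[label1] * r + eye[label2] * (1 - r)).astype(np.float32)
        return image, label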

## Contents
- BC learning for images
    - BC: mix two images by simple internal division (a weighted average).
    - BC+: mix two images treating them as waveforms (see the sketch after this list).
- Training of an 11-layer CNN on the CIFAR datasets
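BC+ derives its mixing coefficient from the per-image standard deviations, treating the zero-mean images as waveforms. A sketch of this step, mirroring `dataset.py` below (the function name is illustrative):

    import numpy as np

    def bc_plus_mix(image1, image2, r):
        # Per-image standard deviations play the role of waveform amplitudes
        g1, g2 = np.std(image1), np.std(image2)
        # p is chosen so the two amplitudes mix with ratio r : (1 - r)
        p = 1.0 / (1 + g1 / g2 * (1 - r) / r)
        # Dividing by sqrt(p^2 + (1 - p)^2) keeps the mixture's amplitude roughly constant
        return (image1 * p + image2 * (1 - p)) / np.sqrt(p ** 2 + (1 - p) ** 2)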


## Setup
- Install [Chainer](https://chainer.org/) v1.24 on a machine with a CUDA-capable GPU.
- Prepare the CIFAR datasets (an example for CIFAR-10 follows this list).
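For CIFAR-10, for example, `dataset.py` expects the pickled python-version batches (`data_batch_1` ... `data_batch_5` and `test_batch`) in the directory passed to `--data`; one way to obtain them:

    wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    tar xzf cifar-10-python.tar.gz  # extracts to cifar-10-batches-py/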


## Training
- Template:

        python main.py --dataset [cifar10 or cifar100] --netType convnet --data path/to/dataset/directory/ (--BC) (--plus)

- Recipes:
    - Standard learning on CIFAR-10 (around 6.1% error):

            python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/

    - BC learning on CIFAR-10 (around 5.4% error):

            python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC

    - BC+ learning on CIFAR-10 (around 5.2% error):

            python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC --plus

- Notes:
    - Training uses the same data augmentation scheme as [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) (padding, random cropping, and horizontal flipping).
    - By default, training is run 10 times; you can change the number of trials with the --nTrials option (see the example after this list).
    - Please check [opts.py](https://github.com/mil-tokyo/bc_learning_image/blob/master/opts.py) for the other command-line arguments.
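For example, a run of BC+ learning with 3 trials that also saves per-trial snapshots (the paths are placeholders):

    python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC --plus --nTrials 3 --save path/to/results/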

## Results

Error rate in % (average of 10 trials):

| Learning | CIFAR-10 | CIFAR-100 |
|:--|:-:|:-:|
| Standard | 6.07 | 26.68 |
| BC (ours) | 5.40 | 24.28 |
| BC+ (ours) | **5.22** | **23.68** |

- Other results (please see [paper](https://arxiv.org/abs/1711.10284)):
- The performance of [Shake-Shake Regularization](https://github.com/xgastaldi/shake-shake) [[1]](#1) on CIFAR-10 was improved from 2.86% to 2.26%.
- The performance of [ResNeXt](https://github.com/facebookresearch/ResNeXt) [[2]](#2) on ImageNet-1K was improved from 20.4% to 19.4% (single-crop top-1 validation error).

---

#### Reference
<i id=1></i>[1] X. Gastaldi. Shake-shake regularization. In *ICLR Workshop*, 2017.

<i id=2></i>[2] S. Xie, R. Girshick, P. Dollár, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In *CVPR*, 2017.
117 changes: 117 additions & 0 deletions dataset.py
@@ -0,0 +1,117 @@
import os
import numpy as np
import random
import cPickle
import chainer

import utils as U


class ImageDataset(chainer.dataset.DatasetMixin):
    def __init__(self, images, labels, opt, train=True):
        self.base = chainer.datasets.TupleDataset(images, labels)
        self.opt = opt
        self.train = train
        self.mix = (opt.BC and train)
        if opt.dataset == 'cifar10':
            if opt.plus:
                self.mean = np.array([4.60, 2.24, -6.84])
                self.std = np.array([55.9, 53.7, 56.5])
            else:
                self.mean = np.array([125.3, 123.0, 113.9])
                self.std = np.array([63.0, 62.1, 66.7])
        else:
            if opt.plus:
                self.mean = np.array([7.37, 2.13, -9.50])
                self.std = np.array([57.6, 54.0, 58.5])
            else:
                self.mean = np.array([129.3, 124.1, 112.4])
                self.std = np.array([68.2, 65.4, 70.4])

        self.preprocess_funcs = self.preprocess_setup()

    def __len__(self):
        return len(self.base)

    def preprocess_setup(self):
        if self.opt.plus:
            normalize = U.zero_mean
        else:
            normalize = U.normalize
        if self.train:
            funcs = [normalize(self.mean, self.std),
                     U.horizontal_flip(),
                     U.padding(4),
                     U.random_crop(32),
                     ]
        else:
            funcs = [normalize(self.mean, self.std)]

        return funcs

    def preprocess(self, image):
        for f in self.preprocess_funcs:
            image = f(image)

        return image

    def get_example(self, i):
        if self.mix:  # Training phase of BC learning
            while True:  # Select two training examples of different classes
                image1, label1 = self.base[random.randint(0, len(self.base) - 1)]
                image2, label2 = self.base[random.randint(0, len(self.base) - 1)]
                if label1 != label2:
                    break
            image1 = self.preprocess(image1)
            image2 = self.preprocess(image2)

            # Mix two images
            r = np.array(random.random())
            if self.opt.plus:
                # BC+: treat the zero-mean images as waveforms; p mixes their
                # amplitudes (per-image std) with ratio r : (1 - r)
                g1 = np.std(image1)
                g2 = np.std(image2)
                p = 1.0 / (1 + g1 / g2 * (1 - r) / r)
                image = ((image1 * p + image2 * (1 - p)) / np.sqrt(p ** 2 + (1 - p) ** 2)).astype(np.float32)
            else:
                image = (image1 * r + image2 * (1 - r)).astype(np.float32)

            # Mix two labels
            eye = np.eye(self.opt.nClasses)
            label = (eye[label1] * r + eye[label2] * (1 - r)).astype(np.float32)

        else:  # Training phase of standard learning or testing phase
            image, label = self.base[i]
            image = self.preprocess(image).astype(np.float32)
            label = np.array(label, dtype=np.int32)

        return image, label


def setup(opt):
    def unpickle(fn):
        with open(fn, 'rb') as f:
            data = cPickle.load(f)
        return data

    if opt.dataset == 'cifar10':
        train = [unpickle(os.path.join(opt.data, 'data_batch_{}'.format(i))) for i in range(1, 6)]
        train_images = np.concatenate([d['data'] for d in train]).reshape((-1, 3, 32, 32))
        train_labels = np.concatenate([d['labels'] for d in train])
        val = unpickle(os.path.join(opt.data, 'test_batch'))
        val_images = val['data'].reshape((-1, 3, 32, 32))
        val_labels = val['labels']
    else:
        train = unpickle(os.path.join(opt.data, 'train'))
        train_images = train['data'].reshape(-1, 3, 32, 32)
        train_labels = train['fine_labels']
        val = unpickle(os.path.join(opt.data, 'test'))
        val_images = val['data'].reshape((-1, 3, 32, 32))
        val_labels = val['fine_labels']

    # Iterator setup
    train_data = ImageDataset(train_images, train_labels, opt, train=True)
    val_data = ImageDataset(val_images, val_labels, opt, train=False)
    train_iter = chainer.iterators.MultiprocessIterator(train_data, opt.batchSize, repeat=False)
    val_iter = chainer.iterators.SerialIterator(val_data, opt.batchSize, repeat=False, shuffle=False)

    return train_iter, val_iter
55 changes: 55 additions & 0 deletions main.py
@@ -0,0 +1,55 @@
"""
Between-class Learning for Image Classification.
Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada
"""

import sys

sys.path.append('/home/harada/local/envs/chainer/site-packages/')

import os
import datetime
import chainer

import opts
import models
import dataset
from train import Trainer


def main():
opt = opts.parse()
chainer.cuda.get_device_from_id(opt.gpu).use()
for i in range(1, opt.nTrials + 1):
print('+-- Trial {} --+'.format(i))
train(opt, i)


def train(opt, trial):
model = getattr(models, opt.netType)(opt.nClasses)
model.to_gpu()
optimizer = chainer.optimizers.NesterovAG(lr=opt.LR, momentum=opt.momentum)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.WeightDecay(opt.weightDecay))
train_iter, val_iter = dataset.setup(opt)
trainer = Trainer(model, optimizer, train_iter, val_iter, opt)

for epoch in range(1, opt.nEpochs + 1):
train_loss, train_top1 = trainer.train(epoch)
val_top1 = trainer.val()
sys.stderr.write('\r\033[K')
sys.stdout.write(
'| Epoch: {}/{} | Train: LR {} Loss {:.3f} top1 {:.2f} | Val: top1 {:.2f}\n'.format(
epoch, opt.nEpochs, trainer.optimizer.lr, train_loss, train_top1, val_top1))
sys.stdout.flush()

if opt.save != 'None':
chainer.serializers.save_npz(
os.path.join(opt.save, 'model_trial{}.npz'.format(datetime.datetime.now())), model)
chainer.serializers.save_npz(
os.path.join(opt.save, 'optimizer_trial{}.npz'.format(datetime.datetime.now())), optimizer)


if __name__ == '__main__':
main()
1 change: 1 addition & 0 deletions models/__init__.py
@@ -0,0 +1 @@
from convnet import ConvNet as convnet
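# Exposes ConvNet under the lowercase name matched by --netType,
# so main.py can instantiate it via getattr(models, opt.netType)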
19 changes: 19 additions & 0 deletions models/convbnrelu.py
@@ -0,0 +1,19 @@
import chainer
import chainer.functions as F
import chainer.links as L


class ConvBNReLU(chainer.Chain):
    def __init__(self, in_channels, out_channels, ksize, stride=1, pad=0,
                 initialW=chainer.initializers.HeNormal(), nobias=True):
        super(ConvBNReLU, self).__init__(
            conv=L.Convolution2D(in_channels, out_channels, ksize, stride, pad,
                                 initialW=initialW, nobias=nobias),
            bn=L.BatchNormalization(out_channels, eps=1e-5)
        )

    def __call__(self, x, train):
        h = self.conv(x)
        h = self.bn(h, test=not train)

        return F.relu(h)
44 changes: 44 additions & 0 deletions models/convnet.py
@@ -0,0 +1,44 @@
import math
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.initializers import Uniform
from convbnrelu import ConvBNReLU


class ConvNet(chainer.Chain):
    def __init__(self, n_classes):
        super(ConvNet, self).__init__(
            conv11=ConvBNReLU(3, 64, 3, pad=1),
            conv12=ConvBNReLU(64, 64, 3, pad=1),
            conv21=ConvBNReLU(64, 128, 3, pad=1),
            conv22=ConvBNReLU(128, 128, 3, pad=1),
            conv31=ConvBNReLU(128, 256, 3, pad=1),
            conv32=ConvBNReLU(256, 256, 3, pad=1),
            conv33=ConvBNReLU(256, 256, 3, pad=1),
            conv34=ConvBNReLU(256, 256, 3, pad=1),
            fc4=L.Linear(256 * 4 * 4, 1024, initialW=Uniform(1. / math.sqrt(256 * 4 * 4))),
            fc5=L.Linear(1024, 1024, initialW=Uniform(1. / math.sqrt(1024))),
            fc6=L.Linear(1024, n_classes, initialW=Uniform(1. / math.sqrt(1024)))
        )
        self.train = True

    def __call__(self, x):
        h = self.conv11(x, self.train)
        h = self.conv12(h, self.train)
        h = F.max_pooling_2d(h, 2)

        h = self.conv21(h, self.train)
        h = self.conv22(h, self.train)
        h = F.max_pooling_2d(h, 2)

        h = self.conv31(h, self.train)
        h = self.conv32(h, self.train)
        h = self.conv33(h, self.train)
        h = self.conv34(h, self.train)
        h = F.max_pooling_2d(h, 2)

        h = F.dropout(F.relu(self.fc4(h)), train=self.train)
        h = F.dropout(F.relu(self.fc5(h)), train=self.train)

        return self.fc6(h)
77 changes: 77 additions & 0 deletions opts.py
@@ -0,0 +1,77 @@
import os
import argparse


def parse():
    parser = argparse.ArgumentParser(description='BC learning for image classification')

    # General settings
    parser.add_argument('--dataset', required=True, choices=['cifar10', 'cifar100'])
    parser.add_argument('--netType', required=True, choices=['convnet'])
    parser.add_argument('--data', required=True, help='Path to dataset')
    parser.add_argument('--nTrials', type=int, default=10)
    parser.add_argument('--save', default='None', help='Directory to save the results')
    parser.add_argument('--gpu', type=int, default=0)

    # Learning settings (a value of -1 is replaced by the dataset defaults below)
    parser.add_argument('--BC', action='store_true', help='BC learning')
    parser.add_argument('--plus', action='store_true', help='Use BC+')
    parser.add_argument('--nEpochs', type=int, default=-1)
    parser.add_argument('--LR', type=float, default=-1, help='Initial learning rate')
    parser.add_argument('--schedule', type=float, nargs='*', default=-1, help='When to divide the LR')
    parser.add_argument('--warmup', type=int, default=-1, help='Number of epochs to warm up')
    parser.add_argument('--batchSize', type=int, default=-1)
    parser.add_argument('--weightDecay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)

    opt = parser.parse_args()
    if opt.plus and not opt.BC:
        raise Exception('--plus can only be used together with --BC.')

    # Dataset details
    if opt.dataset == 'cifar10':
        opt.nClasses = 10
    else:  # cifar100
        opt.nClasses = 100

    # Default settings
    default_settings = dict()
    default_settings['cifar10'] = {
        'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128}
    }
    default_settings['cifar100'] = {
        'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128}
    }
    for key in ['nEpochs', 'LR', 'schedule', 'warmup', 'batchSize']:
        if getattr(opt, key) == -1:
            setattr(opt, key, default_settings[opt.dataset][opt.netType][key])

    if opt.save != 'None' and not os.path.isdir(opt.save):
        os.makedirs(opt.save)

    display_info(opt)

    return opt


def display_info(opt):
    if opt.BC:
        if opt.plus:
            learning = 'BC+'
        else:
            learning = 'BC'
    else:
        learning = 'standard'

    print('+------------------------------+')
    print('| CIFAR classification')
    print('+------------------------------+')
    print('| dataset  : {}'.format(opt.dataset))
    print('| netType  : {}'.format(opt.netType))
    print('| learning : {}'.format(learning))
    print('| nEpochs  : {}'.format(opt.nEpochs))
    print('| LRInit   : {}'.format(opt.LR))
    print('| schedule : {}'.format(opt.schedule))
    print('| warmup   : {}'.format(opt.warmup))
    print('| batchSize: {}'.format(opt.batchSize))
    print('+------------------------------+')
(Diffs for the two remaining files, train.py and utils.py, did not load.)