diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f236977
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+### PyCharm ###
+\.idea/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7603f1c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,67 @@
+BC learning for images
+=========================
+
+Implementation of [Between-class Learning for Image Classification](https://arxiv.org/abs/1711.10284) by Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada.
+
+#### Between-class (BC) learning:
+- We generate between-class examples by mixing two training examples belonging to different classes with a random ratio.
+- We then input the mixed data to the model and
+train the model to output the mixing ratio.
+- Original paper: [Learning from Between-class Examples for Deep Sound Recognition](https://arxiv.org/abs/1711.10282) by us ([github](https://github.com/mil-tokyo/bc_learning_sound))
+
+## Contents
+- BC learning for images
+ - BC: mix two images simply using internal divisions.
+ - BC+: mix two images treating them as waveforms.
+- Training of 11-layer CNN on CIFAR datasets
+
+
+## Setup
+- Install [Chainer](https://chainer.org/) v1.24 on a machine with CUDA GPU.
+- Prepare CIFAR datasets.
+
+
+## Training
+- Template:
+
+ python main.py --dataset [cifar10 or cifar100] --netType convnet --data path/to/dataset/directory/ (--BC) (--plus)
+
+- Recipes:
+ - Standard learning on CIFAR-10 (around 6.1% error):
+
+ python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/
+
+
+ - BC learning on CIFAR-10 (around 5.4% error):
+
+ python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC
+
+ - BC+ learning on CIFAR-10 (around 5.2% error):
+
+ python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC --plus
+
+- Notes:
+ - It uses the same data augmentation scheme as [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch).
+ - By default, it runs training 10 times. You can specify the number of trials by using --nTrials command.
+ - Please check [opts.py](https://github.com/mil-tokyo/bc_learning_image/blob/master/opts.py) for other command line arguments.
+
+## Results
+
+Error rate (average of 10 trials)
+
+| Learning | CIFAR-10 | CIFAR-100 |
+|:--|:-:|:-:|
+| Standard | 6.07 | 26.68 |
+| BC (ours) | 5.40 | 24.28 |
+| BC+ (ours) | **5.22** | **23.68** |
+
+- Other results (please see [paper](https://arxiv.org/abs/1711.10284)):
+ - The performance of [Shake-Shake Regularization](https://github.com/xgastaldi/shake-shake) [[1]](#1) on CIFAR-10 was improved from 2.86% to 2.26%.
+ - The performance of [ResNeXt](https://github.com/facebookresearch/ResNeXt) [[2]](#2) on ImageNet-1K was improved from 20.4% to 19.4% (single-crop top-1 validation error).
+
+---
+
+#### Reference
+[1] X. Gastaldi. Shake-shake regularization. In *ICLR Workshop*, 2017.
+
+[2] S. Xie, R. Girshick, P. Dollar, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In *CVPR*, 2017.
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..eb22a09
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,117 @@
+import os
+import numpy as np
+import random
+import cPickle
+import chainer
+
+import utils as U
+
+
+class ImageDataset(chainer.dataset.DatasetMixin):
+ def __init__(self, images, labels, opt, train=True):
+ self.base = chainer.datasets.TupleDataset(images, labels)
+ self.opt = opt
+ self.train = train
+ self.mix = (opt.BC and train)
+ if opt.dataset == 'cifar10':
+ if opt.plus:
+ self.mean = np.array([4.60, 2.24, -6.84])
+ self.std = np.array([55.9, 53.7, 56.5])
+ else:
+ self.mean = np.array([125.3, 123.0, 113.9])
+ self.std = np.array([63.0, 62.1, 66.7])
+ else:
+ if opt.plus:
+ self.mean = np.array([7.37, 2.13, -9.50])
+ self.std = np.array([57.6, 54.0, 58.5])
+ else:
+ self.mean = np.array([129.3, 124.1, 112.4])
+ self.std = np.array([68.2, 65.4, 70.4])
+
+ self.preprocess_funcs = self.preprocess_setup()
+
+ def __len__(self):
+ return len(self.base)
+
+ def preprocess_setup(self):
+ if self.opt.plus:
+ normalize = U.zero_mean
+ else:
+ normalize = U.normalize
+ if self.train:
+ funcs = [normalize(self.mean, self.std),
+ U.horizontal_flip(),
+ U.padding(4),
+ U.random_crop(32),
+ ]
+ else:
+ funcs = [normalize(self.mean, self.std)]
+
+ return funcs
+
+ def preprocess(self, image):
+ for f in self.preprocess_funcs:
+ image = f(image)
+
+ return image
+
+ def get_example(self, i):
+ if self.mix: # Training phase of BC learning
+ while True: # Select two training examples
+ image1, label1 = self.base[random.randint(0, len(self.base) - 1)]
+ image2, label2 = self.base[random.randint(0, len(self.base) - 1)]
+ if label1 != label2:
+ break
+ image1 = self.preprocess(image1)
+ image2 = self.preprocess(image2)
+
+ # Mix two images
+ r = np.array(random.random())
+ if self.opt.plus:
+ g1 = np.std(image1)
+ g2 = np.std(image2)
+ p = 1.0 / (1 + g1 / g2 * (1 - r) / r)
+ image = ((image1 * p + image2 * (1 - p)) / np.sqrt(p ** 2 + (1 - p) ** 2)).astype(np.float32)
+ else:
+ image = (image1 * r + image2 * (1 - r)).astype(np.float32)
+
+ # Mix two labels
+ eye = np.eye(self.opt.nClasses)
+ label = (eye[label1] * r + eye[label2] * (1 - r)).astype(np.float32)
+
+ else: # Training phase of standard learning or testing phase
+ image, label = self.base[i]
+ image = self.preprocess(image).astype(np.float32)
+ label = np.array(label, dtype=np.int32)
+
+ return image, label
+
+
+def setup(opt):
+ def unpickle(fn):
+ with open(fn, 'rb') as f:
+ data = cPickle.load(f)
+ return data
+
+ if opt.dataset == 'cifar10':
+ train = [unpickle(os.path.join(opt.data, 'data_batch_{}'.format(i))) for i in range(1, 6)]
+ train_images = np.concatenate([d['data'] for d in train]).reshape((-1, 3, 32, 32))
+ train_labels = np.concatenate([d['labels'] for d in train])
+ val = unpickle(os.path.join(opt.data, 'test_batch'))
+ val_images = val['data'].reshape((-1, 3, 32, 32))
+ val_labels = val['labels']
+ else:
+ train = unpickle(os.path.join(opt.data, 'train'))
+ train_images = train['data'].reshape(-1, 3, 32, 32)
+ train_labels = train['fine_labels']
+ val = unpickle(os.path.join(opt.data, 'test'))
+ val_images = val['data'].reshape((-1, 3, 32, 32))
+ val_labels = val['fine_labels']
+
+ # Iterator setup
+ train_data = ImageDataset(train_images, train_labels, opt, train=True)
+ val_data = ImageDataset(val_images, val_labels, opt, train=False)
+ train_iter = chainer.iterators.MultiprocessIterator(train_data, opt.batchSize, repeat=False)
+ val_iter = chainer.iterators.SerialIterator(val_data, opt.batchSize, repeat=False, shuffle=False)
+
+ return train_iter, val_iter
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..4cf9037
--- /dev/null
+++ b/main.py
@@ -0,0 +1,55 @@
+"""
+ Between-class Learning for Image Classification.
+ Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada
+
+"""
+
+import sys
+
+sys.path.append('/home/harada/local/envs/chainer/site-packages/')
+
+import os
+import datetime
+import chainer
+
+import opts
+import models
+import dataset
+from train import Trainer
+
+
+def main():
+ opt = opts.parse()
+ chainer.cuda.get_device_from_id(opt.gpu).use()
+ for i in range(1, opt.nTrials + 1):
+ print('+-- Trial {} --+'.format(i))
+ train(opt, i)
+
+
+def train(opt, trial):
+ model = getattr(models, opt.netType)(opt.nClasses)
+ model.to_gpu()
+ optimizer = chainer.optimizers.NesterovAG(lr=opt.LR, momentum=opt.momentum)
+ optimizer.setup(model)
+ optimizer.add_hook(chainer.optimizer.WeightDecay(opt.weightDecay))
+ train_iter, val_iter = dataset.setup(opt)
+ trainer = Trainer(model, optimizer, train_iter, val_iter, opt)
+
+ for epoch in range(1, opt.nEpochs + 1):
+ train_loss, train_top1 = trainer.train(epoch)
+ val_top1 = trainer.val()
+ sys.stderr.write('\r\033[K')
+ sys.stdout.write(
+ '| Epoch: {}/{} | Train: LR {} Loss {:.3f} top1 {:.2f} | Val: top1 {:.2f}\n'.format(
+ epoch, opt.nEpochs, trainer.optimizer.lr, train_loss, train_top1, val_top1))
+ sys.stdout.flush()
+
+ if opt.save != 'None':
+ chainer.serializers.save_npz(
+ os.path.join(opt.save, 'model_trial{}.npz'.format(datetime.datetime.now())), model)
+ chainer.serializers.save_npz(
+ os.path.join(opt.save, 'optimizer_trial{}.npz'.format(datetime.datetime.now())), optimizer)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..03821fc
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from convnet import ConvNet as convnet
diff --git a/models/convbnrelu.py b/models/convbnrelu.py
new file mode 100644
index 0000000..a199ba2
--- /dev/null
+++ b/models/convbnrelu.py
@@ -0,0 +1,19 @@
+import chainer
+import chainer.functions as F
+import chainer.links as L
+
+
+class ConvBNReLU(chainer.Chain):
+ def __init__(self, in_channels, out_channels, ksize, stride=1, pad=0,
+ initialW=chainer.initializers.HeNormal(), nobias=True):
+ super(ConvBNReLU, self).__init__(
+ conv=L.Convolution2D(in_channels, out_channels, ksize, stride, pad,
+ initialW=initialW, nobias=nobias),
+ bn=L.BatchNormalization(out_channels, eps=1e-5)
+ )
+
+ def __call__(self, x, train):
+ h = self.conv(x)
+ h = self.bn(h, test=not train)
+
+ return F.relu(h)
diff --git a/models/convnet.py b/models/convnet.py
new file mode 100644
index 0000000..d3d759c
--- /dev/null
+++ b/models/convnet.py
@@ -0,0 +1,44 @@
+import math
+import chainer
+import chainer.functions as F
+import chainer.links as L
+from chainer.initializers import Uniform
+from convbnrelu import ConvBNReLU
+
+
+class ConvNet(chainer.Chain):
+ def __init__(self, n_classes):
+ super(ConvNet, self).__init__(
+ conv11=ConvBNReLU(3, 64, 3, pad=1),
+ conv12=ConvBNReLU(64, 64, 3, pad=1),
+ conv21=ConvBNReLU(64, 128, 3, pad=1),
+ conv22=ConvBNReLU(128, 128, 3, pad=1),
+ conv31=ConvBNReLU(128, 256, 3, pad=1),
+ conv32=ConvBNReLU(256, 256, 3, pad=1),
+ conv33=ConvBNReLU(256, 256, 3, pad=1),
+ conv34=ConvBNReLU(256, 256, 3, pad=1),
+ fc4=L.Linear(256 * 4 * 4, 1024, initialW=Uniform(1. / math.sqrt(256 * 4 * 4))),
+ fc5=L.Linear(1024, 1024, initialW=Uniform(1. / math.sqrt(1024))),
+ fc6=L.Linear(1024, n_classes, initialW=Uniform(1. / math.sqrt(1024)))
+ )
+ self.train = True
+
+ def __call__(self, x):
+ h = self.conv11(x, self.train)
+ h = self.conv12(h, self.train)
+ h = F.max_pooling_2d(h, 2)
+
+ h = self.conv21(h, self.train)
+ h = self.conv22(h, self.train)
+ h = F.max_pooling_2d(h, 2)
+
+ h = self.conv31(h, self.train)
+ h = self.conv32(h, self.train)
+ h = self.conv33(h, self.train)
+ h = self.conv34(h, self.train)
+ h = F.max_pooling_2d(h, 2)
+
+ h = F.dropout(F.relu(self.fc4(h)), train=self.train)
+ h = F.dropout(F.relu(self.fc5(h)), train=self.train)
+
+ return self.fc6(h)
diff --git a/opts.py b/opts.py
new file mode 100644
index 0000000..ecb2f90
--- /dev/null
+++ b/opts.py
@@ -0,0 +1,77 @@
+import os
+import argparse
+
+
+def parse():
+ parser = argparse.ArgumentParser(description='BC learning for image classification')
+
+ # General settings
+ parser.add_argument('--dataset', required=True, choices=['cifar10', 'cifar100'])
+ parser.add_argument('--netType', required=True, choices=['convnet'])
+ parser.add_argument('--data', required=True, help='Path to dataset')
+ parser.add_argument('--nTrials', type=int, default=10)
+ parser.add_argument('--save', default='None', help='Directory to save the results')
+ parser.add_argument('--gpu', type=int, default=0)
+
+ # Learning settings
+ parser.add_argument('--BC', action='store_true', help='BC learning')
+ parser.add_argument('--plus', action='store_true', help='Use BC+')
+ parser.add_argument('--nEpochs', type=int, default=-1)
+ parser.add_argument('--LR', type=float, default=-1, help='Initial learning rate')
+ parser.add_argument('--schedule', type=float, nargs='*', default=-1, help='When to divide the LR')
+ parser.add_argument('--warmup', type=int, default=-1, help='Number of epochs to warm up')
+ parser.add_argument('--batchSize', type=int, default=-1)
+ parser.add_argument('--weightDecay', type=float, default=5e-4)
+ parser.add_argument('--momentum', type=float, default=0.9)
+
+ opt = parser.parse_args()
+ if opt.plus and not opt.BC:
+ raise Exception('Using only --plus option is invalid.')
+
+ # Dataset details
+ if opt.dataset == 'cifar10':
+ opt.nClasses = 10
+ else: # cifar100
+ opt.nClasses = 100
+
+ # Default settings
+ default_settings = dict()
+ default_settings['cifar10'] = {
+ 'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128}
+ }
+ default_settings['cifar100'] = {
+ 'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128}
+ }
+ for key in ['nEpochs', 'LR', 'schedule', 'warmup', 'batchSize']:
+ if eval('opt.{}'.format(key)) == -1:
+ setattr(opt, key, default_settings[opt.dataset][opt.netType][key])
+
+ if opt.save != 'None' and not os.path.isdir(opt.save):
+ os.makedirs(opt.save)
+
+ display_info(opt)
+
+ return opt
+
+
+def display_info(opt):
+ if opt.BC:
+ if opt.plus:
+ learning = 'BC+'
+ else:
+ learning = 'BC'
+ else:
+ learning = 'standard'
+
+ print('+------------------------------+')
+ print('| CIFAR classification')
+ print('+------------------------------+')
+ print('| dataset : {}'.format(opt.dataset))
+ print('| netType : {}'.format(opt.netType))
+ print('| learning : {}'.format(learning))
+ print('| nEpochs : {}'.format(opt.nEpochs))
+ print('| LRInit : {}'.format(opt.LR))
+ print('| schedule : {}'.format(opt.schedule))
+ print('| warmup : {}'.format(opt.warmup))
+ print('| batchSize: {}'.format(opt.batchSize))
+ print('+------------------------------+')
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..2ef9995
--- /dev/null
+++ b/train.py
@@ -0,0 +1,82 @@
+import sys
+import numpy as np
+import chainer
+from chainer import cuda
+import chainer.functions as F
+import time
+
+import utils
+
+
+class Trainer:
+ def __init__(self, model, optimizer, train_iter, val_iter, opt):
+ self.model = model
+ self.optimizer = optimizer
+ self.train_iter = train_iter
+ self.val_iter = val_iter
+ self.opt = opt
+ self.n_batches = (len(train_iter.dataset) - 1) // opt.batchSize + 1
+ self.start_time = time.time()
+
+ def train(self, epoch):
+ self.optimizer.lr = self.lr_schedule(epoch)
+ train_loss = 0
+ train_acc = 0
+ for i, batch in enumerate(self.train_iter):
+ x_array, t_array = chainer.dataset.concat_examples(batch)
+ x = chainer.Variable(cuda.to_gpu(x_array))
+ t = chainer.Variable(cuda.to_gpu(t_array))
+ self.optimizer.zero_grads()
+ y = self.model(x)
+ if self.opt.BC:
+ loss = utils.kl_divergence(y, t)
+ acc = F.accuracy(y, F.argmax(t, axis=1))
+ else:
+ loss = F.softmax_cross_entropy(y, t)
+ acc = F.accuracy(y, t)
+
+ loss.backward()
+ self.optimizer.update()
+ train_loss += float(loss.data) * len(t.data)
+ train_acc += float(acc.data) * len(t.data)
+
+ elapsed_time = time.time() - self.start_time
+ progress = (self.n_batches * (epoch - 1) + i + 1) * 1.0 / (self.n_batches * self.opt.nEpochs)
+ eta = elapsed_time / progress - elapsed_time
+
+ line = '* Epoch: {}/{} ({}/{}) | Train: LR {} | Time: {} (ETA: {})'.format(
+ epoch, self.opt.nEpochs, i + 1, self.n_batches,
+ self.optimizer.lr, utils.to_hms(elapsed_time), utils.to_hms(eta))
+ sys.stderr.write('\r\033[K' + line)
+ sys.stderr.flush()
+
+ self.train_iter.reset()
+ train_loss /= len(self.train_iter.dataset)
+ train_top1 = 100 * (1 - train_acc / len(self.train_iter.dataset))
+
+ return train_loss, train_top1
+
+ def val(self):
+ self.model.train = False
+ val_acc = 0
+ for batch in self.val_iter:
+ x_array, t_array = chainer.dataset.concat_examples(batch)
+ x = chainer.Variable(cuda.to_gpu(x_array), volatile=True)
+ t = chainer.Variable(cuda.to_gpu(t_array), volatile=True)
+ y = F.softmax(self.model(x))
+ acc = F.accuracy(y, t)
+ val_acc += float(acc.data) * len(t.data)
+
+ self.val_iter.reset()
+ self.model.train = True
+ val_top1 = 100 * (1 - val_acc / len(self.val_iter.dataset))
+
+ return val_top1
+
+ def lr_schedule(self, epoch):
+ divide_epoch = np.array([self.opt.nEpochs * i for i in self.opt.schedule])
+ decay = sum(epoch > divide_epoch)
+ if epoch <= self.opt.warmup:
+ decay = 1
+
+ return self.opt.LR * np.power(0.1, decay)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..d0ec32f
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,64 @@
+import numpy as np
+import random
+import chainer.functions as F
+
+
+def padding(pad):
+ def f(image):
+ return np.pad(image, ((0, 0), (pad, pad), (pad, pad)), 'constant')
+
+ return f
+
+
+def random_crop(size):
+ def f(image):
+ _, h, w = image.shape
+ p = random.randint(0, h - size)
+ q = random.randint(0, w - size)
+ return image[:, p: p + size, q: q + size]
+
+ return f
+
+
+def horizontal_flip():
+ def f(image):
+ if random.randint(0, 1):
+ image = image[:, :, ::-1]
+ return image
+
+ return f
+
+
+def normalize(mean, std):
+ def f(image):
+ return (image - mean[:, None, None]) / std[:, None, None]
+
+ return f
+
+
+# For BC+
+def zero_mean(mean, std):
+ def f(image):
+ image_mean = np.mean(image, keepdims=True)
+ return (image - image_mean - mean[:, None, None]) / std[:, None, None]
+
+ return f
+
+
+def kl_divergence(y, t):
+ entropy = - F.sum(t[t.data.nonzero()] * F.log(t[t.data.nonzero()]))
+ crossEntropy = - F.sum(t * F.log_softmax(y))
+
+ return (crossEntropy - entropy) / y.shape[0]
+
+
+def to_hms(time):
+ h = int(time // 3600)
+ m = int((time - h * 3600) // 60)
+ s = int(time - h * 3600 - m * 60)
+ if h > 0:
+ line = '{}h{:02d}m'.format(h, m)
+ else:
+ line = '{}m{:02d}s'.format(m, s)
+
+ return line