Skip to content

Commit

Permalink
Experiment with layer-wise fine-tuning on ImageNet
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Aug 31, 2019
1 parent e94ae89 commit 0121997
Show file tree
Hide file tree
Showing 2 changed files with 485 additions and 3 deletions.
121 changes: 118 additions & 3 deletions cnn/imagenet/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torchvision.models as models
from . import logger as log
# from . import resnet as models
from . import utils

from cnn.mobilenet_imagenet import MobileNet
from cnn.mobilenet_imagenet import Butterfly1x1Conv

try:
from apex.parallel import DistributedDataParallel as DDP
Expand Down Expand Up @@ -282,8 +284,8 @@ def train(train_loader, model_and_loss, optimizer, lr_scheduler, fp16, logger, e

def get_val_step(model_and_loss):
def _step(input, target):
input_var = Variable(input)
target_var = Variable(target)
input_var = input
target_var = target

with torch.no_grad():
loss, output = model_and_loss(input_var, target_var)
Expand Down Expand Up @@ -391,3 +393,116 @@ def train_loop(model_and_loss, optimizer, lr_scheduler, train_loader, val_loader
'optimizer' : optimizer.state_dict(),
}, is_best, checkpoint_dir=checkpoint_dir, backup_filename=backup_filename)
# }}}

def get_input_cov(model, train_loader, layer_names, max_batches=None):
    """Estimate the second moment E[X^T X] of the inputs to the given layers.

    A forward pre-hook is attached to every module named in ``layer_names``.
    Each hooked module must receive a single 4-D input (unpacked in the hook
    as ``b, c, h, w``); the input is flattened to ``(b * h * w, c)`` and a
    running average over batches of ``x^T x / (b * h * w)`` is kept.

    Args:
        model: Module whose ``named_modules()`` contains every entry of
            ``layer_names``. It is switched to eval mode and left there.
        train_loader: Iterable yielding ``(input, target)`` pairs; only the
            input is used and it is passed to ``model`` as-is. If it has a
            ``dalipipeline`` attribute, that pipeline is reset when iteration
            stops early because of ``max_batches``.
        layer_names: Names (keys of ``model.named_modules()``) of the layers
            whose inputs are measured.
        max_batches: Maximum number of batches to run through the model, or
            ``None`` to consume the whole loader. At least one batch is always
            processed.

    Returns:
        Dict mapping each layer name to its ``(c, c)`` second-moment matrix.
        When ``torch.distributed`` is initialized every matrix is passed
        through ``utils.reduce_tensor`` first.

    Raises:
        KeyError: If a name in ``layer_names`` is not a module of ``model``.

    Note:
        The running statistics are stored on the hooked modules as ``_count``,
        ``_mean`` and ``_cov``. They are cleared at the start of every call so
        repeated calls do not blend new statistics into old ones.
    """
    stat_attrs = ('_count', '_mean', '_cov')

    # hook to capture intermediate inputs
    def hook(module, input):
        x, = input
        b, c, h, w = x.shape
        # Treat every spatial position of every sample as one observation.
        x = x.permute(0, 2, 3, 1).reshape(b * h * w, c)
        module._count = getattr(module, '_count', 0) + 1
        # Compute the first moment E[X], averaged over batches.
        current_mean = x.mean(dim=0)
        if not hasattr(module, '_mean'):
            module._mean = current_mean
        else:
            module._mean += (current_mean - module._mean) / module._count
        # Compute the covariance (actually 2nd moment) E[X^T X], averaged over batches.
        current_cov = (x.t() @ x) / x.shape[0]
        if not hasattr(module, '_cov'):
            module._cov = current_cov
        else:
            module._cov += (current_cov - module._cov) / module._count

    module_dict = dict(model.named_modules())
    # Resolve every name before registering anything, so an unknown layer name
    # cannot leave hooks behind on the layers that were already processed.
    target_modules = [module_dict[layer_name] for layer_name in layer_names]

    # Drop statistics left by a previous call. Without this the running
    # averages would continue from the old values and the in-place updates
    # would modify tensors that were already returned to the caller.
    for module in target_modules:
        for attr in stat_attrs:
            if hasattr(module, attr):
                delattr(module, attr)

    hook_handles = [module.register_forward_pre_hook(hook) for module in target_modules]

    model.eval()
    try:
        for batch_idx, (input, _) in enumerate(train_loader):
            with torch.no_grad():
                model(input)
            # batch_idx is zero-based, so batch_idx + 1 batches have been seen.
            if max_batches is not None and batch_idx + 1 >= max_batches:
                if hasattr(train_loader, 'dalipipeline'):
                    train_loader.dalipipeline.reset()
                break
    finally:
        # Always detach the hooks, even when a forward pass raises.
        for handle in hook_handles:
            handle.remove()

    # Only the second moment is returned; the running mean stays on the modules.
    cov = {layer_name: module_dict[layer_name]._cov for layer_name in layer_names}
    if torch.distributed.is_initialized():
        cov = {layer_name: utils.reduce_tensor(c.data) for layer_name, c in cov.items()}
    return cov


def butterfly_projection_cov(teacher_module, input_cov, butterfly_structure='odo_1',
                             n_Adam_steps=20000, n_LBFGS_steps=50):
    """Fit a ``Butterfly1x1Conv`` student to a teacher layer under an input second moment.

    ``input_cov`` is eigendecomposed as ``U diag(Sigma) U^T`` and the rows of
    ``diag(sqrt(Sigma)) U^T`` are used as a batch of ``in_channels`` synthetic
    1x1 inputs. The student is trained to reproduce the teacher's output on
    that batch with an MSE loss, first with Adam and then with L-BFGS.

    Args:
        teacher_module: Layer exposing ``in_channels`` and ``out_channels``
            that can be called on an ``(in_channels, in_channels, 1, 1)``
            tensor.
        input_cov: Symmetric ``(in_channels, in_channels)`` matrix, e.g. one
            entry of the dict returned by ``get_input_cov``.
        butterfly_structure: String of the form ``'<param>[res][_<nblocks>]'``.
            The part before the first underscore is passed to
            ``Butterfly1x1Conv`` as ``param`` after removing ``'res'``; if it
            ends with ``'res'`` the input is added to the student output
            before computing the loss. The optional second part is
            ``nblocks`` (0 when absent).
        n_Adam_steps: Number of Adam iterations.
        n_LBFGS_steps: Number of L-BFGS ``step`` calls that follow.

    Returns:
        Tuple ``(student_module, loss)`` where ``loss`` is the MSE of the
        returned student, as a Python float. When ``torch.distributed`` is
        initialized, the parameters and the loss come from the rank with the
        lowest loss.

    Raises:
        ValueError: If ``teacher_module`` lacks ``in_channels`` or
            ``out_channels``.
    """
    try:
        in_channels = teacher_module.in_channels
        out_channels = teacher_module.out_channels
    except AttributeError as err:
        # Only a missing attribute means "not a conv layer"; any other error
        # is left to propagate unchanged.
        raise ValueError("Only convolutional layers currently supported.") from err

    structure_parts = butterfly_structure.split('_')
    param = structure_parts[0]
    residual = param.endswith('res')
    param = param.replace('res', '')
    nblocks = int(structure_parts[1]) if len(structure_parts) > 1 else 0
    student_module = Butterfly1x1Conv(in_channels, out_channels,
                                      bias=False, tied_weight=False, ortho_init=True,
                                      param=param, nblocks=nblocks)
    student_module = student_module.to(input_cov.device)

    with torch.no_grad():
        # torch.symeig is deprecated in favour of torch.linalg.eigh; prefer
        # the latter when present. UPLO='U' matches symeig's default of
        # reading the upper triangle.
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            Sigma, U = torch.linalg.eigh(input_cov, UPLO='U')
        else:
            Sigma, U = torch.symeig(input_cov, eigenvectors=True)
        Sigma = Sigma.clamp(0)  # avoid small negative eigenvalues
        inputs = torch.diag(Sigma.sqrt()) @ U.t()
        inputs = inputs.reshape(in_channels, in_channels, 1, 1)  # To be compatible with conv2d
        targets = teacher_module(inputs)
        # Normalize input so that output has MSE 1.0
        # NOTE(review): an all-zero teacher output makes this a division by
        # zero; callers are assumed to pass a non-degenerate teacher.
        inputs /= (targets ** 2).mean().sqrt()
        targets = teacher_module(inputs)

    def loss_fn():
        output = student_module(inputs)
        if residual:
            if output.shape[1] == 2 * inputs.shape[1]:
                # Student doubles the channel count: add the input to each
                # of the two channel groups.
                b, c, h, w = inputs.shape
                output = (output.reshape(b, 2, c, h, w)
                          + inputs.reshape(b, 1, c, h, w)).reshape(b, 2 * c, h, w)
            else:
                output = output + inputs
        return F.mse_loss(output, targets)

    optimizer = optim.Adam(student_module.parameters())
    student_module.train()
    for _ in range(n_Adam_steps):
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        optimizer.step()

    optimizer = optim.LBFGS(filter(lambda p: p.requires_grad, student_module.parameters()),
                            tolerance_grad=1e-7,  # Pytorch 1.2 sets this too high https://github.com/pytorch/pytorch/pull/25240
                            line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = loss_fn()
        loss.backward()
        return loss

    for _ in range(n_LBFGS_steps):
        optimizer.step(closure)

    # Evaluate the final parameters explicitly. The values produced inside
    # the loops are measured before the last update, and no value exists at
    # all when both step counts are zero.
    with torch.no_grad():
        loss = loss_fn()

    if torch.distributed.is_initialized():
        # Get the model from the process with the lowest loss.
        # Losses could be different due to different initialization of student_module.
        all_losses = [torch.empty_like(loss) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(all_losses, loss)
        loss_values = [rank_loss.item() for rank_loss in all_losses]
        # Every rank sees the same gathered values, so every rank picks the
        # same best_rank. NaN (v != v) is ranked last so a diverged run is
        # never selected over a finite one.
        best_rank = min(range(len(loss_values)),
                        key=lambda r: float('inf') if loss_values[r] != loss_values[r] else loss_values[r])
        loss = all_losses[best_rank]
        for p in student_module.parameters():
            torch.distributed.broadcast(p, best_rank)

    return student_module, loss.item()
Loading

0 comments on commit 0121997

Please sign in to comment.