
Commit

update
HaoranZhuExplorer committed Apr 6, 2020
1 parent 59279c3 commit 30dd500
Showing 6 changed files with 198 additions and 72 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -12,8 +12,10 @@ Please see [AI Project 2 Report](/README.assets/AI_Project2_Report.pdf)

We followed the idea from *[Overcoming catastrophic forgetting in neural networks (EWC)](https://arxiv.org/abs/1612.00796)* and conducted experiments on *both* MNIST *and* CORe50 datasets. Results show that our model beats the ResNet-18 baseline by 2% on CORe50.
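
For reference, the EWC objective from the paper adds a quadratic penalty that anchors the parameters that were important for the old task, weighted by a diagonal Fisher information estimate:

$$\mathcal{L}(\theta) = \mathcal{L}_{new}(\theta) + \frac{\lambda}{2}\sum_i F_i\,\big(\theta_i - \theta_i^{*}\big)^2$$

where $\theta^{*}$ are the parameters learned on the old task, $F_i$ is the diagonal Fisher estimate for parameter $i$, and $\lambda$ is the old/new trade-off (`old_new_rate` in our code).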

# Technical Information
# Running on the MNIST and Fashion-MNIST datasets

```bash
python ewc_mnist.py
```

# Running on the CORe50 dataset
## Setup

Follow the [CVPR Starter](https://github.com/vlomonaco/cvpr_clvision_challenge) Instruction
@@ -24,6 +26,7 @@ conda env create -f environment.yml
conda activate clvision-challenge
```


## Run

```bash
90 changes: 90 additions & 0 deletions ewc/EWC.py
@@ -0,0 +1,90 @@

# -*- coding: utf-8 -*-
"""EWC.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/12cvrWSH0i5LYE8c4ATDfMP-g02fd5c0t
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.utils.data import DataLoader

# hyperparameters
# loss function criterion
criterion = nn.NLLLoss()
# learning rate
lr = 0.001
# how strongly the old-task (EWC) penalty is weighted against the new-task loss
old_new_rate = 10000
# optimizer
optimizer_name = "Adam"
epoch = 1
train_batch_size = 64
device = "cuda"
log_interval = 1000
max_iter_per_epoch = 100

class EWC:
    # keep a reference to the model and the old/new trade-off weight
    def __init__(self, model, old_new_rate):
        self.model = model.to(device)
        self.old_new_rate = old_new_rate
        self.approximate_mean = 0
        self.approximate_fisher_information_matrix = 0

    # loss that anchors parameters to the previous task, using the stored
    # parameter means and a diagonal Fisher approximation to keep the
    # computation cheap
    def get_old_task_loss(self):
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (self.old_new_rate / 2) * sum(losses)
        except AttributeError:
            # no mean/Fisher buffers registered yet (first task): no penalty
            return 0

    # train the model on one batch: new-task loss plus the EWC penalty
    # (note: the optimizer and its state are recreated for every batch)
    def train(self, data, target):
        if optimizer_name == "Adam":
            optimizer = optim.Adam(self.model.parameters(), lr=lr)
        else:
            raise ValueError("unsupported optimizer: {}".format(optimizer_name))
        data, target = data.to(device), target.to(device)
        output = self.model(data)

        optimizer.zero_grad()
        loss_new_task = criterion(output, target)
        loss_old_task = self.get_old_task_loss()
        loss = loss_new_task + loss_old_task
        loss.backward()
        optimizer.step()

    # update the stored parameter means and the diagonal Fisher estimates;
    # call this once training on the current task is finished
    def update(self, current_ds, batch_size, num_batch):
        # store the current parameters as the "mean" for the finished task
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name + '_estimated_mean', param.data.clone())

        # estimate the diagonal Fisher information as the squared gradient of
        # the data log-likelihood, averaged over a few batches of this task
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_likelihoods = []
        for i, (input, target) in enumerate(dl):
            if i > num_batch:
                break
            input = input.to(device)
            target = target.to(device)
            self.model = self.model.to(device)
            output = F.log_softmax(self.model(input), dim=1)
            log_likelihoods.append(output[:, target])
        log_likelihood = torch.cat(log_likelihoods).mean()
        grad_log_likelihood = autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name, param in zip(_buff_param_names, grad_log_likelihood):
            self.model.register_buffer(_buff_param_name + '_estimated_fisher', param.data.clone() ** 2)
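
A minimal usage sketch of this class (illustrative only: the toy datasets and the small classifier below are made up, and a CUDA device is assumed per the module-level `device` setting). Train batch by batch on a task, then call `update` once the task is done so later tasks are regularized toward these weights.

```python
# Hypothetical usage of ewc.EWC; the datasets and classifier are illustrative only.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ewc.EWC import EWC

# two toy "tasks" standing in for MNIST-like splits
task_a = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
task_b = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))

# small classifier emitting log-probabilities, to match the nn.NLLLoss criterion
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(),
                      nn.Linear(128, 10), nn.LogSoftmax(dim=1))
ewc = EWC(model, old_new_rate=10000)

for task in (task_a, task_b):
    for data, target in DataLoader(task, batch_size=64, shuffle=True):
        ewc.train(data, target)  # new-task loss + EWC penalty from earlier tasks
    # snapshot parameter means and diagonal Fisher estimates for this task
    ewc.update(task, batch_size=64, num_batch=100)
```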
64 changes: 0 additions & 64 deletions ewc/elastic_weight_consolidation.py

This file was deleted.

90 changes: 90 additions & 0 deletions ewc_mnist.py
@@ -0,0 +1,90 @@

# -*- coding: utf-8 -*-
"""EWC.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/12cvrWSH0i5LYE8c4ATDfMP-g02fd5c0t
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.utils.data import DataLoader

# hyperparameters
# loss function criterion
criterion = nn.NLLLoss()
# learning rate
lr = 0.001
# how strongly the old-task (EWC) penalty is weighted against the new-task loss
old_new_rate = 10000
# optimizer
optimizer_name = "Adam"
epoch = 1
train_batch_size = 64
device = "cuda"
log_interval = 1000
max_iter_per_epoch = 100

class EWC:
    # keep a reference to the model and the old/new trade-off weight
    def __init__(self, model, old_new_rate):
        self.model = model.to(device)
        self.old_new_rate = old_new_rate
        self.approximate_mean = 0
        self.approximate_fisher_information_matrix = 0

    # loss that anchors parameters to the previous task, using the stored
    # parameter means and a diagonal Fisher approximation to keep the
    # computation cheap
    def get_old_task_loss(self):
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (self.old_new_rate / 2) * sum(losses)
        except AttributeError:
            # no mean/Fisher buffers registered yet (first task): no penalty
            return 0

    # train the model on one batch: new-task loss plus the EWC penalty
    # (note: the optimizer and its state are recreated for every batch)
    def train(self, data, target):
        if optimizer_name == "Adam":
            optimizer = optim.Adam(self.model.parameters(), lr=lr)
        else:
            raise ValueError("unsupported optimizer: {}".format(optimizer_name))
        data, target = data.to(device), target.to(device)
        output = self.model(data)

        optimizer.zero_grad()
        loss_new_task = criterion(output, target)
        loss_old_task = self.get_old_task_loss()
        loss = loss_new_task + loss_old_task
        loss.backward()
        optimizer.step()

    # update the stored parameter means and the diagonal Fisher estimates;
    # call this once training on the current task is finished
    def update(self, current_ds, batch_size, num_batch):
        # store the current parameters as the "mean" for the finished task
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name + '_estimated_mean', param.data.clone())

        # estimate the diagonal Fisher information as the squared gradient of
        # the data log-likelihood, averaged over a few batches of this task
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_likelihoods = []
        for i, (input, target) in enumerate(dl):
            if i > num_batch:
                break
            input = input.to(device)
            target = target.to(device)
            self.model = self.model.to(device)
            output = F.log_softmax(self.model(input), dim=1)
            log_likelihoods.append(output[:, target])
        log_likelihood = torch.cat(log_likelihoods).mean()
        grad_log_likelihood = autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name, param in zip(_buff_param_names, grad_log_likelihood):
            self.model.register_buffer(_buff_param_name + '_estimated_fisher', param.data.clone() ** 2)
17 changes: 12 additions & 5 deletions main_ewc.py
@@ -37,7 +37,10 @@
import torchvision.models as models
from utils.common import create_code_snapshot

from ewc.elastic_weight_consolidation import ElasticWeightConsolidation
from ewc.EWC import EWC




def main(args):

@@ -63,7 +66,7 @@ def main(args):

opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
criterion = torch.nn.CrossEntropyLoss()
ewc = ElasticWeightConsolidation(classifier, crit=criterion, lr=1e-4, weight=args.ewc_weight)
ewc = EWC(classifier, old_new_rate=args.ewc_weight)

# vars to update over time
valid_acc = []
@@ -99,9 +102,10 @@ def main(args):
# train the classifier on the current batch/task
_, _, stats, preprocessed_dataset = train_net_ewc(
opt, ewc, criterion, args.batch_size, train_x, train_y, t,
args.epochs, preproc=preprocess_imgs
args.epochs, preproc=preprocess_imgs,
ewc_explosion_multr_cap=args.ewc_explosion_multr_cap,
)
ewc.register_ewc_params(preprocessed_dataset, args.batch_size, dataset.nbatch[dataset.scenario])
ewc.update(preprocessed_dataset, args.batch_size, dataset.nbatch[dataset.scenario])

if args.scenario == "multi-task-nc":
heads.append(copy.deepcopy(classifier.fc))
@@ -174,8 +178,11 @@ def main(args):
parser.add_argument('--replay_examples', type=int, default=0,
help='data examples to keep in memory for each batch '
'for replay.')
parser.add_argument('--ewc_weight', type=int, default=100,
parser.add_argument('--ewc_weight', type=int, default=50,
help='weight for elastic weight consolidation.')
parser.add_argument('--ewc_explosion_multr_cap', type=int, default=10,
help='limit max multiplier of ewc loss.')


# Misc
parser.add_argument('--sub_dir', type=str, default="multi-task-nc",
4 changes: 2 additions & 2 deletions utils/train_test_ewc.py
@@ -27,7 +27,7 @@
from torch.autograd import Variable
from .common import pad_data, shuffle_in_unison, check_ext_mem, check_ram_usage

from ewc.elastic_weight_consolidation import ElasticWeightConsolidation
from ewc.EWC import EWC
from torch.utils.data import TensorDataset

def train_net_ewc(optimizer, ewc, criterion, mb_size, x, y, t,
@@ -103,7 +103,7 @@ def train_net_ewc(optimizer, ewc, criterion, mb_size, x, y, t,
_, pred_label = torch.max(logits, 1)
correct_cnt += (pred_label == y_mb).sum()
ewc_loss = maybe_cuda(torch.as_tensor(
ewc._compute_consolidation_loss(ewc.weight),
ewc.get_old_task_loss(),
dtype=torch.float32), use_cuda=use_cuda)
loss = criterion(logits, y_mb) + torch.min(
loss_explosion_cap * criterion(logits, y_mb), ewc_loss
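
The capped loss above bounds the EWC contribution: the total loss is the cross-entropy on the current mini-batch plus the smaller of the raw EWC penalty and `ewc_explosion_multr_cap` times that cross-entropy, so an exploding penalty cannot swamp the task loss. A tiny numeric sketch of the rule (values are made up):

```python
# Illustrative only: the capping rule applied in train_net_ewc above.
import torch

task_loss = torch.tensor(0.7)      # cross-entropy on the current mini-batch (made up)
ewc_penalty = torch.tensor(120.0)  # uncapped EWC penalty (made up)
cap = 10                           # --ewc_explosion_multr_cap

total = task_loss + torch.min(cap * task_loss, ewc_penalty)
print(total)  # tensor(7.7000) -- the penalty is clipped to cap * task_loss
```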
