Showing 7 changed files with 751 additions and 0 deletions.
@@ -0,0 +1,27 @@
```json
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Baseline",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/main.py",
            "args": [
                "--scenario=multi-task-nc",
                "--sub_dir=baseline-debug"
            ],
            "console": "integratedTerminal"
        },
        {
            "name": "EWC",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/main_ewc.py",
            "args": [
                "--scenario=multi-task-nc",
                "--sub_dir=ewc-debug"
            ],
            "console": "integratedTerminal"
        }
    ]
}
```
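These two debug launch configurations run the baseline (main.py) and EWC (main_ewc.py) entry points on the multi-task-nc scenario, writing results to separate submission subdirectories; they mirror the commands listed in the README below.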
@@ -0,0 +1,15 @@
# Setup

```bash
sh fetch_data_and_setup.sh
conda env create -f environment.yml
conda activate clvision-challenge
sh create_submission.sh
```

# Run

```bash
python main.py --scenario="multi-task-nc" --epochs="5" --sub_dir="baseline"
python main_ewc.py --scenario="multi-task-nc" --epochs="5" --sub_dir="ewc"
```
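main.py also accepts --lr, --batch_size, and --replay_examples (the number of examples per incremental batch kept in a replay buffer), as its argument parser below shows; main_ewc.py presumably exposes the same flags, though that file is not reproduced in this excerpt.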
@@ -0,0 +1,64 @@
```python
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
from torch.utils.data import DataLoader


class ElasticWeightConsolidation:
    """Wraps a model with an EWC penalty: (weight / 2) * sum_i F_i * (theta_i - theta*_i) ** 2."""

    def __init__(self, model, crit, lr=0.001, weight=1000000):
        self.model = model
        self.weight = weight
        self.crit = crit
        self.optimizer = optim.Adam(self.model.parameters(), lr)

    def _update_mean_params(self):
        # Snapshot the current parameters as the anchor (theta*) for the penalty.
        for param_name, param in self.model.named_parameters():
            _buff_param_name = param_name.replace('.', '__')
            self.model.register_buffer(_buff_param_name + '_estimated_mean', param.data.clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        # Estimate the diagonal Fisher information from the squared gradients
        # of the average log-likelihood over num_batch mini-batches.
        dl = DataLoader(current_ds, batch_size, shuffle=True)
        log_likelihoods = []
        for i, (input, target) in enumerate(dl):
            if i >= num_batch:
                break
            output = F.log_softmax(self.model(input), dim=1)
            # Gather the log-probability of each sample's own target class.
            log_likelihoods.append(output[range(len(target)), target])
        log_likelihood = torch.cat(log_likelihoods).mean()
        grad_log_likelihood = autograd.grad(log_likelihood, self.model.parameters())
        _buff_param_names = [param[0].replace('.', '__') for param in self.model.named_parameters()]
        for _buff_param_name, param in zip(_buff_param_names, grad_log_likelihood):
            self.model.register_buffer(_buff_param_name + '_estimated_fisher', param.data.clone() ** 2)

    def register_ewc_params(self, dataset, batch_size, num_batches):
        # Call this after finishing a task to consolidate what was learned.
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                _buff_param_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(_buff_param_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(_buff_param_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (weight / 2) * sum(losses)
        except AttributeError:
            # No EWC buffers registered yet (first task): no penalty.
            return 0

    def forward_backward_update(self, input, target):
        output = self.model(input)
        loss = self._compute_consolidation_loss(self.weight) + self.crit(output, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filename):
        torch.save(self.model, filename)

    def load(self, filename):
        self.model = torch.load(filename)
```
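For context, here is a minimal usage sketch of this class; the toy model, dataset, and tensor shapes are illustrative assumptions, not part of this commit.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy setup (hypothetical): a small classifier over 10-dim inputs, 5 classes.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
ewc = ElasticWeightConsolidation(model, crit=nn.CrossEntropyLoss(), lr=0.001)

# Task A: train normally (the penalty is zero until buffers are registered),
# then estimate the Fisher information and snapshot the weights.
task_a = TensorDataset(torch.randn(256, 10), torch.randint(0, 5, (256,)))
for x, y in DataLoader(task_a, batch_size=32):
    ewc.forward_backward_update(x, y)
ewc.register_ewc_params(task_a, batch_size=32, num_batches=4)

# Task B: every update now also pays the quadratic penalty for drifting
# away from the parameters consolidated on task A.
task_b = TensorDataset(torch.randn(256, 10), torch.randint(0, 5, (256,)))
for x, y in DataLoader(task_b, batch_size=32):
    ewc.forward_backward_update(x, y)
```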
@@ -0,0 +1,185 @@
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-

################################################################################
# Copyright (c) 2019. Vincenzo Lomonaco, Massimo Caccia, Pau Rodriguez,        #
# Lorenzo Pellegrini. All rights reserved.                                     #
# Copyrights licensed under the CC BY 4.0 License.                             #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 1-02-2019                                                              #
# Author: Vincenzo Lomonaco                                                    #
# E-mail: [email protected]                                                    #
# Website: vincenzolomonaco.com                                                #
################################################################################

"""
Getting Started example for the CVPR 2020 CLVision Challenge. It will load the
data and create the submission file for you in the
cvpr_clvision_challenge/submissions directory.
"""

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import argparse
import os
import time
import copy
from core50.dataset import CORE50
import torch
import numpy as np
from utils.train_test import train_net, test_multitask, preprocess_imgs
import torchvision.models as models
from utils.common import create_code_snapshot


def main(args):

    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # Create the dataset object for one of the "ni", "multi-task-nc", or
    # "nic" tracks, assuming the CORe50 data lives in ./core50/data/
    dataset = CORE50(root='core50/data/', scenario=args.scenario,
                     preload=args.preload_data)

    # Get the validation set
    print("Recovering validation set...")
    full_validset = dataset.get_full_valid_set()

    # model
    if args.classifier == 'ResNet18':
        classifier = models.resnet18(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)

    opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # vars to update over time
    valid_acc = []
    ext_mem_sz = []
    ram_usage = []
    heads = []
    ext_mem = None

    # loop over the training incremental batches (x, y, t)
    for i, train_batch in enumerate(dataset):
        train_x, train_y, t = train_batch

        # add replay patterns sampled from previous batches to the current one
        idxs_cur = np.random.choice(
            train_x.shape[0], args.replay_examples, replace=False
        )

        if i == 0:
            ext_mem = [train_x[idxs_cur], train_y[idxs_cur]]
        else:
            ext_mem = [
                np.concatenate((train_x[idxs_cur], ext_mem[0])),
                np.concatenate((train_y[idxs_cur], ext_mem[1]))]

        train_x = np.concatenate((train_x, ext_mem[0]))
        train_y = np.concatenate((train_y, ext_mem[1]))

        print("----------- batch {0} -------------".format(i))
        print("x shape: {0}, y shape: {1}"
              .format(train_x.shape, train_y.shape))
        print("Task Label: ", t)

        # train the classifier on the current batch/task
        _, _, stats = train_net(
            opt, classifier, criterion, args.batch_size, train_x, train_y, t,
            args.epochs, preproc=preprocess_imgs
        )
        if args.scenario == "multi-task-nc":
            heads.append(copy.deepcopy(classifier.fc))

        # collect statistics
        ext_mem_sz += stats['disk']
        ram_usage += stats['ram']

        # test on the validation set
        stats, _ = test_multitask(
            classifier, full_validset, args.batch_size,
            preproc=preprocess_imgs, multi_heads=heads, verbose=False
        )

        valid_acc += stats['acc']
        print("------------------------------------------")
        print("Avg. acc: {}".format(stats['acc']))
        print("------------------------------------------")

    # Generate submission.zip:
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generate metadata.txt with all the data used for the CLScore
    elapsed = (time.time() - start) / 60
    print("Training Time: {}m".format(elapsed))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
            np.average(valid_acc), elapsed, np.average(ram_usage),
            np.max(ram_usage), np.average(ext_mem_sz), np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")
    print(f'Average Accuracy Over Time on the Validation Set: {np.average(valid_acc)}')
    print(f'Total Training/Test time: {elapsed} Minutes')
    print(f'Average RAM Usage: {np.average(ram_usage)} MB')
    print(f'Max RAM Usage: {np.max(ram_usage)} MB')

    print("Experiment completed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser('CVPR Continual Learning Challenge')

    # General
    parser.add_argument('--scenario', type=str, default="multi-task-nc",
                        choices=['ni', 'multi-task-nc', 'nic'])
    parser.add_argument('--preload_data', type=bool, default=True,
                        help='preload data into RAM')
    parser.add_argument('--no_preload_data', dest='preload_data',
                        action='store_false')

    # Model
    parser.add_argument('-cls', '--classifier', type=str, default='ResNet18',
                        choices=['ResNet18'])

    # Optimization
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size')
    parser.add_argument('--epochs', type=int, default=1,
                        help='number of epochs')

    # Continual Learning
    parser.add_argument('--replay_examples', type=int, default=0,
                        help='data examples to keep in memory for each batch '
                             'for replay')

    # Misc
    parser.add_argument('--sub_dir', type=str, default="multi-task-nc",
                        help='directory of the submission file for this exp.')

    args = parser.parse_args()
    args.n_classes = 50
    args.input_size = [3, 128, 128]

    args.cuda = torch.cuda.is_available()
    args.device = 'cuda:0' if args.cuda else 'cpu'

    main(args)
```
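main_ewc.py is referenced by the README and the launch configuration but is not among the files reproduced in this excerpt. Purely as an illustration, here is a minimal sketch of how its per-batch step might replace the train_net call above with the ElasticWeightConsolidation wrapper; the function name, tensor conversion, and num_batches value are assumptions, and image preprocessing is omitted.

```python
# Hypothetical sketch, not the actual main_ewc.py from this commit.
import torch
from torch.utils.data import DataLoader, TensorDataset

def train_batch_with_ewc(ewc, train_x, train_y, batch_size, epochs):
    """Train on one incremental batch, then consolidate it with EWC."""
    ds = TensorDataset(torch.as_tensor(train_x, dtype=torch.float32),
                       torch.as_tensor(train_y, dtype=torch.long))
    for _ in range(epochs):
        for x, y in DataLoader(ds, batch_size=batch_size, shuffle=True):
            # Cross-entropy on the current batch plus the quadratic
            # penalty toward parameters consolidated on earlier batches.
            ewc.forward_backward_update(x, y)
    # After the batch: estimate the Fisher diagonal on its data and
    # snapshot the current weights as the new anchor.
    ewc.register_ewc_params(ds, batch_size, num_batches=8)
```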