Feature replay and feature distillation (CIL Exfree) #1535

Merged: 8 commits, Jan 25, 2024
32 changes: 28 additions & 4 deletions avalanche/models/utils.py
@@ -1,10 +1,11 @@
-from avalanche.benchmarks.utils import _make_taskaware_classification_dataset
-from avalanche.models.dynamic_modules import MultiTaskModule, DynamicModule
+from collections import OrderedDict

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
-from collections import OrderedDict

 from avalanche.benchmarks.scenarios import CLExperience
+from avalanche.benchmarks.utils import _make_taskaware_classification_dataset
+from avalanche.models.dynamic_modules import DynamicModule, MultiTaskModule


def is_multi_task_module(model: nn.Module) -> bool:
@@ -75,6 +76,23 @@ def add_hooks(self, model):
)


class FeatureExtractorModel(nn.Module):
"""
Feature extractor that additionnaly stores the features
"""

def __init__(self, feature_extractor, train_classifier):
super().__init__()
self.feature_extractor = feature_extractor
self.train_classifier = train_classifier
self.features = None

def forward(self, x):
self.features = self.feature_extractor(x)
x = self.train_classifier(self.features)
return x


class Flatten(nn.Module):
"""
Simple nn.Module to flatten each tensor of a batch of tensors.
@@ -118,4 +136,10 @@ def forward(self, x):
return self.mlp(x)


__all__ = ["avalanche_forward", "FeatureExtractorBackbone", "MLP", "Flatten"]
__all__ = [
"avalanche_forward",
"FeatureExtractorBackbone",
"MLP",
"Flatten",
"FeatureExtractorModel",
]
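A usage note for reviewers: the new wrapper splits a network into a backbone and a head, caching the backbone output on every forward pass. A minimal sketch of the intended use (the toy backbone, head, and shapes below are illustrative assumptions, not part of this diff):

import torch
import torch.nn as nn

from avalanche.models.utils import FeatureExtractorModel

# Hypothetical backbone/head pair; any pair of nn.Modules works.
backbone = nn.Sequential(nn.Linear(6, 10), nn.ReLU())
head = nn.Linear(10, 2)

model = FeatureExtractorModel(backbone, head)
logits = model(torch.randn(4, 6))  # forward() caches the features
print(model.features.shape)        # torch.Size([4, 10])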
1 change: 1 addition & 0 deletions avalanche/training/plugins/__init__.py
@@ -25,3 +25,4 @@
from .rar import RARPlugin
from .update_ncm import *
from .update_fecam import *
from .feature_distillation import *
62 changes: 62 additions & 0 deletions avalanche/training/plugins/feature_distillation.py
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
import copy

import torch
import torch.nn.functional as F

from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.utils import _at_task_boundary


class FeatureDistillationPlugin(SupervisedPlugin):
def __init__(self, alpha=1, mode="cosine"):
"""
Adds a Distillation loss term on the features of the model,
trying to maximize the cosine similarity between current and old features

:param alpha: distillation hyperparameter. It can be either a float
number or a list containing alpha for each experience.
"""
super().__init__()
self.alpha = alpha
self.prev_model = None
assert mode in ["mse", "cosine"]
self.mode = mode

def before_backward(self, strategy, **kwargs):
"""
Add distillation loss
"""
if self.prev_model is None:
return

        with torch.no_grad():
            # The forward pass caches the old features on the frozen model
            avalanche_forward(self.prev_model, strategy.mb_x, strategy.mb_task_id)
            old_features = self.prev_model.features

new_features = strategy.model.features

if self.mode == "cosine":
strategy.loss += self.alpha * (
1 - F.cosine_similarity(new_features, old_features, dim=1).mean()
)
elif self.mode == "mse":
            strategy.loss += self.alpha * F.mse_loss(new_features, old_features)

def after_training_exp(self, strategy, **kwargs):
"""
Save a copy of the model after each experience and
update self.prev_classes to include the newly learned classes.
"""
if _at_task_boundary(strategy.experience, before=False):
strategy.model.features = None
self.prev_model = copy.deepcopy(strategy.model)


__all__ = ["FeatureDistillationPlugin"]
1 change: 1 addition & 0 deletions avalanche/training/supervised/__init__.py
@@ -16,3 +16,4 @@
from .supervised_contrastive_replay import SCR
from .expert_gate import ExpertGateStrategy
from .mer import MER
from .feature_replay import *
152 changes: 152 additions & 0 deletions avalanche/training/supervised/feature_replay.py
@@ -0,0 +1,152 @@
#!/usr/bin/env python3
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from avalanche.core import SupervisedPlugin
from avalanche.models.utils import FeatureExtractorModel, avalanche_forward
from avalanche.training.losses import MaskedCrossEntropy
from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.utils import cycle


class FeatureDataset(torch.utils.data.Dataset):
"""
Wrapper around features tensor dataset
Required for compatibility with storage policy
"""

def __init__(self, data, targets):
self.data = data
self.targets = targets

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index], self.targets[index]


class FeatureReplay(SupervisedTemplate):
"""
Store some last layer features and use them for replay

Replay is performed following the PR-ACE protocol
defined in Magistri et al. https://openreview.net/forum?id=7D9X2cFnt1

Training the current task with masked cross entropy for current task classes
and training the classifier with cross entropy
criterion over all previously seen classes
"""

def __init__(
self,
model: nn.Module,
optimizer: Optimizer,
criterion=MaskedCrossEntropy(),
last_layer_name: str = "classifier",
mem_size: int = 200,
batch_size_mem: int = 10,
train_mb_size: int = 1,
train_epochs: int = 1,
eval_mb_size: Optional[int] = 1,
device: Union[str, torch.device] = "cpu",
plugins: Optional[List[SupervisedPlugin]] = None,
evaluator: Union[
EvaluationPlugin, Callable[[], EvaluationPlugin]
] = default_evaluator,
eval_every=-1,
peval_mode="epoch",
):
super().__init__(
model,
optimizer,
criterion,
train_mb_size,
train_epochs,
eval_mb_size,
device,
plugins,
evaluator,
eval_every,
peval_mode,
)
self.mem_size = mem_size
self.batch_size_mem = batch_size_mem
self.storage_policy = ClassBalancedBuffer(
max_size=self.mem_size, adaptive_size=True
)
self.replay_loader = None

# Criterion used when doing feature replay
# self._criterion is used only on current data
self.full_criterion = nn.CrossEntropyLoss()

        # Turn the model into a feature extractor
last_layer = getattr(self.model, last_layer_name)
setattr(self.model, last_layer_name, nn.Identity())
self.model = FeatureExtractorModel(self.model, last_layer)

def _after_training_exp(self, **kwargs):
super()._after_training_exp(**kwargs)
feature_dataset = self.gather_feature_dataset()
self.storage_policy.update_from_dataset(feature_dataset)

if self.batch_size_mem is None:
self.batch_size_mem = self.train_mb_size

# Create replay loader
self.replay_loader = cycle(
torch.utils.data.DataLoader(
self.storage_policy.buffer,
batch_size=self.batch_size_mem,
shuffle=True,
drop_last=True,
)
)

def criterion(self, **kwargs):
if self.replay_loader is None:
return self._criterion(self.mb_output, self.mb_y)

batch_feats, batch_y, batch_t = next(self.replay_loader)
batch_feats, batch_y = batch_feats.to(self.device), batch_y.to(self.device)
out = self.model.train_classifier(batch_feats)

        # PR-ACE weighting: the current-task cross-entropy term
        # decays as more experiences are seen
        weight_current = 1 / (self.experience.current_experience + 1)

mb_output = self.model.train_classifier(self.model.features.detach())

loss = 0.5 * self._criterion(self.mb_output, self.mb_y) + 0.5 * (
(1 - weight_current) * self.full_criterion(out, batch_y)
+ weight_current * self.full_criterion(mb_output, self.mb_y)
)
return loss

@torch.no_grad()
def gather_feature_dataset(self):
self.model.eval()
dataloader = torch.utils.data.DataLoader(
self.experience.dataset, batch_size=self.train_mb_size, shuffle=True
)
all_features = []
all_labels = []
for x, y, t in dataloader:
x, y = x.to(self.device), y.to(self.device)
feats = avalanche_forward(self.model.feature_extractor, x, t)
all_features.append(feats.cpu())
all_labels.append(y.cpu())
all_features = torch.cat(all_features)
all_labels = torch.cat(all_labels)
features_dataset = FeatureDataset(all_features, all_labels)
return features_dataset


__all__ = ["FeatureReplay"]
62 changes: 57 additions & 5 deletions tests/training/test_strategies.py
@@ -24,6 +24,7 @@
EvaluationPlugin,
SupervisedPlugin,
LwFPlugin,
FeatureDistillationPlugin,
ReplayPlugin,
RWalkPlugin,
EarlyStoppingPlugin,
@@ -51,6 +52,7 @@
LearningToPrompt,
ExpertGateStrategy,
MER,
FeatureReplay,
)
from avalanche.training.supervised.cumulative import Cumulative
from avalanche.training.supervised.icarl import ICaRL
@@ -62,6 +64,7 @@
from tests.training.test_strategy_utils import run_strategy
from tests.unit_tests_utils import get_fast_benchmark, get_device
from torchvision import transforms
from avalanche.models.utils import FeatureExtractorModel


class BaseStrategyTest(unittest.TestCase):
@@ -850,13 +853,14 @@ def test_expertgate(self):
strategy.train(experience)

def test_icarl(self):
-        model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
-        model = SimpleMLP(input_size=6, hidden_size=10)
-        optimizer = SGD(model.parameters(), lr=0.000001)
+        _, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
+        fe = torch.nn.Linear(6, 10)
+        cls = torch.nn.Sigmoid()
+        optimizer = SGD([*fe.parameters(), *cls.parameters()], lr=0.001)

strategy = ICaRL(
-            model.features,
-            model.classifier,
+            fe,
+            cls,
optimizer,
20,
buffer_transform=None,
@@ -1054,6 +1058,54 @@ def test_der(self):
)
run_strategy(benchmark, strategy)

def test_feature_distillation(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)

        # Wrap the model so that its features are exposed to the plugin
last_fc_name, _ = get_last_fc_layer(model)
old_layer = getattr(model, last_fc_name)
setattr(model, last_fc_name, torch.nn.Identity())
model = FeatureExtractorModel(model, old_layer)

feature_distillation = FeatureDistillationPlugin(alpha=10)

plugins = [feature_distillation]

strategy = Naive(
model,
optimizer,
criterion,
device=self.device,
train_mb_size=10,
eval_mb_size=50,
train_epochs=2,
plugins=plugins,
)
run_strategy(benchmark, strategy)

def test_feature_replay(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)

last_fc_name, _ = get_last_fc_layer(model)

plugins = []

        # No criterion is passed because FeatureReplay uses
        # MaskedCrossEntropy as its main criterion by default
strategy = FeatureReplay(
model,
optimizer,
device=self.device,
last_layer_name=last_fc_name,
train_mb_size=10,
eval_mb_size=50,
train_epochs=2,
plugins=plugins,
)
run_strategy(benchmark, strategy)

def load_benchmark(
self,
use_task_labels=False,