diff --git a/avalanche/models/bic_model.py b/avalanche/models/bic_model.py index e8de493a5..889c9e63e 100644 --- a/avalanche/models/bic_model.py +++ b/avalanche/models/bic_model.py @@ -1,3 +1,4 @@ +from typing import Iterable, SupportsInt import torch @@ -10,25 +11,28 @@ class BiasLayer(torch.nn.Module): Recognition. 2019" """ - def __init__(self, device, clss): + def __init__(self, clss: Iterable[SupportsInt]): """ - :param device: device used by the main model. 'cpu' or 'cuda' :param clss: list of classes of the current layer. This are use to identify the columns which are multiplied by the Bias correction Layer. """ super().__init__() - self.alpha = torch.nn.Parameter(torch.ones(1, device=device)) - self.beta = torch.nn.Parameter(torch.zeros(1, device=device)) + self.alpha = torch.nn.Parameter(torch.ones(1)) + self.beta = torch.nn.Parameter(torch.zeros(1)) - self.clss = torch.Tensor(list(clss)).long().to(device) - self.not_clss = None + unique_classes = list(sorted(set(int(x) for x in clss))) + + self.register_buffer("clss", torch.tensor(unique_classes, dtype=torch.long)) def forward(self, x): alpha = torch.ones_like(x) - beta = torch.ones_like(x) + beta = torch.zeros_like(x) alpha[:, self.clss] = self.alpha beta[:, self.clss] = self.beta return alpha * x + beta + + +__all__ = ["BiasLayer"] diff --git a/avalanche/training/plugins/bic.py b/avalanche/training/plugins/bic.py index d5d0780bf..61535fd37 100644 --- a/avalanche/training/plugins/bic.py +++ b/avalanche/training/plugins/bic.py @@ -2,24 +2,24 @@ from typing import ( Dict, List, + Literal, Optional, - TYPE_CHECKING, Sequence, Set, SupportsInt, + Union, ) from copy import deepcopy import torch +from torch import Tensor +from torch.nn import Module from torch.utils.data import DataLoader from torch.optim.lr_scheduler import MultiStepLR -from avalanche.benchmarks.utils import ( - _taskaware_classification_subset, - _concat_taskaware_classification_datasets, -) from avalanche.benchmarks.utils.data import 
AvalancheDataset from avalanche.benchmarks.utils.data_loader import ReplayDataLoader +from avalanche.benchmarks.utils.utils import concat_datasets from avalanche.training.plugins.strategy_plugin import SupervisedPlugin from avalanche.training.storage_policy import ( ExemplarsBuffer, @@ -29,8 +29,7 @@ from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models.bic_model import BiasLayer -if TYPE_CHECKING: - from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates import SupervisedTemplate class BiCPlugin(SupervisedPlugin): @@ -58,6 +57,8 @@ def __init__( stage_2_epochs: int = 200, lamb: float = -1, lr: float = 0.1, + num_workers: Union[int, Literal["as_strategy"]] = "as_strategy", + verbose: bool = False, ): """ :param mem_size: replay buffer size. @@ -83,6 +84,10 @@ def __init__( loss and the classification loss. :param lr: hyperparameter used as a learning rate for the second phase of training. + :param num_workers: number of workers used during stage 2 data loading. + Defaults to "as_strategy", which means that the number of workers + will be the same as the one used by the strategy. 
+ :param verbose: if True, prints additional info regarding stage 2 (bias correction) """ # Replay (Phase 1) @@ -107,17 +112,17 @@ def __init__( self.lamb = lamb self.mem_size = mem_size self.lr = lr + self.num_workers: Union[int, Literal["as_strategy"]] = num_workers self.seen_classes: Set[int] = set() self.class_to_tasks: Dict[int, int] = {} - self.bias_layer: Dict[int, BiasLayer] = {} - self.model_old = None + self.bias_layer: Optional[BiasLayer] = None + self.model_old: Optional[Module] = None self.val_buffer: Dict[int, ReservoirSamplingBuffer] = {} - # TODO: remove ext_mem - # @property - # def ext_mem(self): - # return self.storage_policy.buffer_groups # a Dict + self.is_first_experience: bool = True + + self.verbose: bool = verbose def before_training(self, strategy: "SupervisedTemplate", *args, **kwargs): assert not isinstance( @@ -165,9 +170,7 @@ def before_train_dataset_adaptation(self, strategy: "SupervisedTemplate", **kwar for class_id, class_buf in self.val_buffer.items(): class_buf.resize(strategy, class_to_len[class_id]) - strategy.experience.dataset = _concat_taskaware_classification_datasets( - train_data - ) + strategy.experience.dataset = concat_datasets(train_data) def before_training_exp( self, @@ -181,11 +184,13 @@ def before_training_exp( the training dataset """ assert strategy.adapted_dataset is not None - task_id = strategy.clock.train_exp_counter - if task_id not in self.bias_layer: - targets = getattr(strategy.adapted_dataset, "targets") - self.bias_layer[task_id] = BiasLayer(strategy.device, list(targets.uniques)) + # During the distillation phase this layer is not trained and is only + # used to correct the bias of the classes encountered in the previous experience. + # It will be unlocked in the bias correction phase. + if self.bias_layer is not None: + for param in self.bias_layer.parameters(): + param.requires_grad = False if len(self.storage_policy.buffer) == 0: # first experience. 
We don't use the buffer, no need to change @@ -211,108 +216,269 @@ def before_training_exp( shuffle=shuffle, ) - def after_forward(self, strategy, **kwargs): - for t in self.bias_layer.keys(): - strategy.mb_output = self.bias_layer[t](strategy.mb_output) - def after_eval_forward(self, strategy, **kwargs): - for t in self.bias_layer.keys(): - strategy.mb_output = self.bias_layer[t](strategy.mb_output) + if self.is_first_experience: + # https://github.com/wuyuebupt/LargeScaleIncrementalLearning/blob/7f687a323ae3629109b35c369b547af74a94e73d/resnet.py#L488 + return + + strategy.mb_output = self.bias_forward(strategy.mb_output) + + def bias_forward(self, input_data: Tensor) -> Tensor: + if self.bias_layer is None: + return input_data + + return self.bias_layer(input_data) def before_backward(self, strategy, **kwargs): - # Distill - task_id = strategy.clock.train_exp_counter + # Distillation + if self.model_old is not None: # That is, from the second experience onwards + distillation_loss = self.make_distillation_loss(strategy) - if self.model_old is not None: - out_old = self.model_old(strategy.mb_x.to(strategy.device)) - out_new = strategy.model(strategy.mb_x.to(strategy.device)) + # Count the number of already seen classes (i.e., classes from previous experiences) + initial_classes, previous_classes, current_classes = self._classes_groups( + strategy + ) - old_clss = [] - for c in self.class_to_tasks.keys(): - if self.class_to_tasks[c] < task_id: - old_clss.append(c) + # Make old_classes and all_classes + old_clss: Set[int] = set(initial_classes) | set(previous_classes) + all_clss: Set[int] = old_clss | set(current_classes) - loss_dist = self.cross_entropy(out_new[:, old_clss], out_old[:, old_clss]) if self.lamb == -1: - lamb = len(old_clss) / len(self.seen_classes) - return (1.0 - lamb) * strategy.loss + lamb * loss_dist + lamb = len(old_clss) / len(all_clss) + strategy.loss = (1.0 - lamb) * strategy.loss + lamb * distillation_loss else: - return strategy.loss + 
self.lamb * loss_dist + strategy.loss = strategy.loss + self.lamb * distillation_loss def after_training_exp(self, strategy, **kwargs): + self.is_first_experience = False + + # Make sure that the old_model is frozen (including batch norm layers) + # requires_grad=False is not sufficient to freeze BN layers, + # we also need eval() + self.model_old = None self.model_old = deepcopy(strategy.model) + self.model_old.eval() + for param in self.model_old.parameters(): + param.requires_grad = False + task_id = strategy.clock.train_exp_counter self.storage_policy.update(strategy, **kwargs) if task_id > 0: - list_subsets = [] - for _, class_buf in self.val_buffer.items(): - list_subsets.append(class_buf.buffer) - - stage_set = _concat_taskaware_classification_datasets(list_subsets) - stage_loader = DataLoader( - stage_set, - batch_size=strategy.train_mb_size, - shuffle=True, - num_workers=4, + num_workers = ( + int(kwargs.get("num_workers", 0)) + if self.num_workers == "as_strategy" + else self.num_workers + ) + persistent_workers = ( + False if num_workers == 0 else kwargs.get("persistent_workers", False) ) - bic_optimizer = torch.optim.SGD( - self.bias_layer[task_id].parameters(), lr=self.lr, momentum=0.9 + self.bias_correction_step( + strategy, + persistent_workers=persistent_workers, + num_workers=num_workers, ) - # verbose here is actually correct - # The PyTorch type stubs for MultiStepLR are broken - scheduler = MultiStepLR( - bic_optimizer, milestones=[50, 100, 150], gamma=0.1, verbose=False - ) # type: ignore + def cross_entropy(self, new_outputs, old_outputs): + """Calculates cross-entropy with temperature scaling""" + # logp = torch.nn.functional.log_softmax(new_outputs / self.T, dim=1) + # pre_p = torch.nn.functional.softmax(old_outputs / self.T, dim=1) + # return -torch.mean(torch.sum(pre_p * logp, dim=1)) * self.T * self.T + + # The previous implementation (above), multiplied the final loss by T^2, which is not correct. 
+ # In addition, this is more aligned to how it's done in the original implementation (softmax over the class dimension, dim=1). + dis_logits_soft = torch.nn.functional.softmax(old_outputs / self.T, dim=1) + loss_distill = torch.nn.functional.cross_entropy( + new_outputs / self.T, dis_logits_soft + ) + return loss_distill + + def get_group_lengths(self, num_groups): + """Compute groups lengths given the number of groups `num_groups`.""" + max_size = int(self.val_percentage * self.mem_size) + lengths = [max_size // num_groups for _ in range(num_groups)] + # distribute remaining size among experiences. + rem = max_size - sum(lengths) + for i in range(rem): + lengths[i] += 1 - # Loop epochs - for e in range(self.stage_2_epochs): - total, t_acc, t_loss = 0, 0, 0 - for inputs in stage_loader: - x = inputs[0].to(strategy.device) - y_real = inputs[1].to(strategy.device) + return lengths + def make_distillation_loss(self, strategy): + assert self.model_old is not None + initial_classes, previous_classes, current_classes = self._classes_groups( + strategy + ) + # print('initial_classes', initial_classes, 'previous_classes', previous_classes, 'current_classes', current_classes) + + # Forward current minibatch through the old model + with torch.no_grad(): + out_old: Tensor = self.model_old(strategy.mb_x) + + if len(initial_classes) == 0: + # We are in the second experience, no need to correct the bias + # https://github.com/wuyuebupt/LargeScaleIncrementalLearning/blob/7f687a323ae3629109b35c369b547af74a94e73d/resnet.py#L561 + pass + else: + # We are in the third experience or later + # bias_forward will apply the bias correction to the output of the old model for the classes + # found in previous_classes (bias correction is not applied to initial_classes or current_classes)! 
+ # https://github.com/wuyuebupt/LargeScaleIncrementalLearning/blob/7f687a323ae3629109b35c369b547af74a94e73d/resnet.py#L564 + assert self.bias_layer is not None + assert set(self.bias_layer.clss.tolist()) == set(previous_classes) + with torch.no_grad(): + # out_old_before = out_old.clone() + out_old = self.bias_forward(out_old) + + # Asserts commented out for performance reasons. + # Remove the comments if you want to check that the bias correction is applied correctly. + # assert torch.equal(out_old_before[:, initial_classes], out_old[:, initial_classes]) + # assert torch.equal(out_old_before[:, current_classes], out_old[:, current_classes]) + # assert not torch.equal(out_old_before[:, previous_classes], out_old[:, previous_classes]) + + # To compute the distillation loss, we need the output of the new model + # without the bias correction. During train, the output of the new model + # does not undergo bias correction, so we can use mb_output directly. + out_new: Tensor = strategy.mb_output + + # Union of initial_classes and previous_classes: needed to select the logits of all the old classes + old_clss: List[int] = sorted(set(initial_classes) | set(previous_classes)) + + # Distillation loss on the logits of the old classes + return self.cross_entropy(out_new[:, old_clss], out_old[:, old_clss]) + + def bias_correction_step( + self, + strategy: SupervisedTemplate, + persistent_workers: bool = False, + num_workers: int = 0, + ): + # --- Prepare the models --- + # Freeze the base model, only train the new bias layer + strategy.model.eval() + + # Note: we use torch.no_grad for this. + # In this way, we don't need to store the status of each requires_grad + # which is useful when we have multiple parameters with different + # requires_grad status. 
+ # for param in strategy.model.parameters(): + # param.requires_grad = False + + # Create the bias layer of the current experience + targets = getattr(strategy.adapted_dataset, "targets") + self.bias_layer = BiasLayer(targets.uniques) + self.bias_layer.to(strategy.device) + self.bias_layer.train() + for param in self.bias_layer.parameters(): + param.requires_grad = True + + bic_optimizer = torch.optim.SGD( + self.bias_layer.parameters(), lr=self.lr, momentum=0.9 + ) + + # Typing note: verbose here is actually correct + # The PyTorch type stubs for MultiStepLR are broken in some versions + scheduler = MultiStepLR( + bic_optimizer, milestones=[50, 100, 150], gamma=0.1, verbose=False + ) # type: ignore + + # --- Prepare the dataloader for the validation set --- + list_subsets: List[AvalancheDataset] = [] + for _, class_buf in self.val_buffer.items(): + list_subsets.append(class_buf.buffer) + + stage_set = concat_datasets(list_subsets) + stage_loader = DataLoader( + stage_set, + batch_size=strategy.train_mb_size, + shuffle=True, + num_workers=num_workers, + persistent_workers=persistent_workers, + ) + + # Loop epochs + for e in range(self.stage_2_epochs): + total, t_acc, t_loss = 0, 0, 0 + for inputs in stage_loader: + x = inputs[0].to(strategy.device) + y_real = inputs[1].to(strategy.device) + + with torch.no_grad(): outputs = strategy.model(x) - for t in self.bias_layer.keys(): - outputs = self.bias_layer[t](outputs) - loss = torch.nn.functional.cross_entropy(outputs, y_real) + outputs = self.bias_layer(outputs) + + loss = torch.nn.functional.cross_entropy(outputs, y_real) - _, preds = torch.max(outputs, 1) - t_acc += torch.sum(preds == y_real.data) - t_loss += loss.item() * x.size(0) - total += x.size(0) + _, preds = torch.max(outputs, 1) + t_acc += torch.sum(preds == y_real.data) + t_loss += loss.item() * x.size(0) + total += x.size(0) - loss += 0.1 * ((self.bias_layer[task_id].beta.sum() ** 2) / 2) + # Hand-made L2 loss + # 
https://github.com/wuyuebupt/LargeScaleIncrementalLearning/blob/7f687a323ae3629109b35c369b547af74a94e73d/resnet.py#L636 + loss += 0.1 * ((self.bias_layer.beta.sum() ** 2) / 2) - bic_optimizer.zero_grad() - loss.backward() - bic_optimizer.step() + bic_optimizer.zero_grad() + loss.backward() + bic_optimizer.step() - scheduler.step() - if (e + 1) % (int(self.stage_2_epochs / 4)) == 0: + scheduler.step() + if self.verbose and (self.stage_2_epochs // 4) > 0: + if (e + 1) % (self.stage_2_epochs // 4) == 0: print( "| E {:3d} | Train: loss={:.3f}, S2 acc={:5.1f}% |".format( e + 1, t_loss / total, 100 * t_acc / total ) ) - def cross_entropy(self, outputs, targets): - """Calculates cross-entropy with temperature scaling""" - logp = torch.nn.functional.log_softmax(outputs / self.T, dim=1) - pre_p = torch.nn.functional.softmax(targets / self.T, dim=1) - return -torch.mean(torch.sum(pre_p * logp, dim=1)) * self.T * self.T + # Freeze the bias layer + self.bias_layer.eval() + for param in self.bias_layer.parameters(): + param.requires_grad = False - def get_group_lengths(self, num_groups): - """Compute groups lengths given the number of groups `num_groups`.""" - max_size = int(self.val_percentage * self.mem_size) - lengths = [max_size // num_groups for _ in range(num_groups)] - # distribute remaining size among experiences. 
- rem = max_size - sum(lengths) - for i in range(rem): - lengths[i] += 1 + if self.verbose: + print( + "Bias correction done: alpha={}, beta={}".format( + self.bias_layer.alpha.item(), self.bias_layer.beta.item() + ) + ) - return lengths + def _classes_groups(self, strategy: SupervisedTemplate): + current_experience: int = strategy.experience.current_experience + # Split between + # - "initial" classes: seen in experiences [0, current_experience-2] + # - "previous" classes: seen in current_experience-1 + # - "current" classes: seen in current_experience + + # "initial" classes + initial_classes: Set[ + int + ] = set() # pre_initial_cl in the original implementation + previous_classes: Set[int] = set() # pre_new_cl in the original implementation + current_classes: Set[int] = set() # new_cl in the original implementation + # Note: pre_initial_cl + pre_new_cl is "initial_cl" in the original implementation + + for cls, exp_id in self.class_to_tasks.items(): + assert exp_id >= 0 + assert exp_id <= current_experience + + if exp_id < current_experience - 1: + initial_classes.add(cls) + elif exp_id == current_experience - 1: + previous_classes.add(cls) + else: + current_classes.add(cls) + + return ( + sorted(initial_classes), + sorted(previous_classes), + sorted(current_classes), + ) + + +__all__ = [ + "BiCPlugin", +] diff --git a/avalanche/training/storage_policy.py b/avalanche/training/storage_policy.py index 907e93715..c0df18c59 100644 --- a/avalanche/training/storage_policy.py +++ b/avalanche/training/storage_policy.py @@ -461,7 +461,10 @@ def update(self, strategy: "SupervisedTemplate", **kwargs): self.update_from_dataset(strategy, new_data) def update_from_dataset(self, strategy, new_data): - self.buffer = self.buffer.concat(new_data) + if len(self.buffer) == 0: + self.buffer = new_data + else: + self.buffer = self.buffer.concat(new_data) self.resize(strategy, self.max_size) def resize(self, strategy, new_size: int): diff --git 
a/tests/benchmarks/scenarios/test_task_aware.py b/tests/benchmarks/scenarios/test_task_aware.py index 4d340e0e1..a7822f661 100644 --- a/tests/benchmarks/scenarios/test_task_aware.py +++ b/tests/benchmarks/scenarios/test_task_aware.py @@ -9,19 +9,30 @@ from avalanche.benchmarks.utils.classification_dataset import ClassificationDataset from avalanche.benchmarks.utils.data_attribute import DataAttribute from torch.utils.data import TensorDataset +import numpy as np class TestsTaskAware(unittest.TestCase): def test_taskaware(self): """Common use case: add tas labels to class-incremental benchmark.""" n_classes, n_samples_per_class, n_features = 10, 3, 7 - dataset = make_classification( - n_samples=n_classes * n_samples_per_class, - n_classes=n_classes, - n_features=n_features, - n_informative=6, - n_redundant=0, - ) + + for _ in range(10000): + dataset = make_classification( + n_samples=n_classes * n_samples_per_class, + n_classes=n_classes, + n_features=n_features, + n_informative=6, + n_redundant=0, + ) + + # The following check is required to ensure that at least 2 exemplars + # per class are generated. Otherwise, the train_test_split function will + # fail. + _, unique_count = np.unique(dataset[1], return_counts=True) + if np.min(unique_count) > 1: + break + X = torch.from_numpy(dataset[0]).float() y = torch.from_numpy(dataset[1]).long() train_X, test_X, train_y, test_y = train_test_split(