diff --git a/avalanche/training/plugins/bic.py b/avalanche/training/plugins/bic.py
index aa7a142d2..61535fd37 100644
--- a/avalanche/training/plugins/bic.py
+++ b/avalanche/training/plugins/bic.py
@@ -2,10 +2,12 @@
 from typing import (
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
     SupportsInt,
+    Union,
 )
 
 from copy import deepcopy
@@ -15,11 +17,9 @@
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import MultiStepLR
 
-from avalanche.benchmarks.utils import (
-    _concat_taskaware_classification_datasets,
-)
 from avalanche.benchmarks.utils.data import AvalancheDataset
 from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
+from avalanche.benchmarks.utils.utils import concat_datasets
 from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
 from avalanche.training.storage_policy import (
     ExemplarsBuffer,
@@ -57,7 +57,7 @@ def __init__(
         stage_2_epochs: int = 200,
         lamb: float = -1,
         lr: float = 0.1,
-        num_workers: int = 4,
+        num_workers: Union[int, Literal["as_strategy"]] = "as_strategy",
         verbose: bool = False,
     ):
         """
@@ -84,7 +84,9 @@
             loss and the classification loss.
         :param lr: hyperparameter used as a learning rate for
             the second phase of training.
-        :param num_workers: number of workers using during stage 2 data loading
+        :param num_workers: number of workers used during stage 2 data loading.
+            Defaults to "as_strategy", which means that the number of workers
+            will be the same as the one used by the strategy.
         :param verbose: if True, prints additional info regarding
             the stage 2 stage
         """
@@ -110,7 +112,7 @@
         self.lamb = lamb
         self.mem_size = mem_size
         self.lr = lr
-        self.num_workers = num_workers
+        self.num_workers: Union[int, Literal["as_strategy"]] = num_workers
 
         self.seen_classes: Set[int] = set()
         self.class_to_tasks: Dict[int, int] = {}
@@ -168,9 +170,7 @@ def before_train_dataset_adaptation(self, strategy: "SupervisedTemplate", **kwar
         for class_id, class_buf in self.val_buffer.items():
             class_buf.resize(strategy, class_to_len[class_id])
 
-        strategy.experience.dataset = _concat_taskaware_classification_datasets(
-            train_data
-        )
+        strategy.experience.dataset = concat_datasets(train_data)
 
     def before_training_exp(
         self,
@@ -266,8 +266,19 @@ def after_training_exp(self, strategy, **kwargs):
         self.storage_policy.update(strategy, **kwargs)
 
         if task_id > 0:
+            num_workers = (
+                int(kwargs.get("num_workers", 0))
+                if self.num_workers == "as_strategy"
+                else self.num_workers
+            )
+            persistent_workers = (
+                False if num_workers == 0 else kwargs.get("persistent_workers", False)
+            )
+
             self.bias_correction_step(
-                strategy, persistent_workers=kwargs.get("persistent_workers", False)
+                strategy,
+                persistent_workers=persistent_workers,
+                num_workers=num_workers,
             )
 
     def cross_entropy(self, new_outputs, old_outputs):
@@ -334,13 +345,15 @@ def make_distillation_loss(self, strategy):
 
         # Union of initial_classes and previous_classes: needed to select the logits of all the old classes
         old_clss: List[int] = sorted(set(initial_classes) | set(previous_classes))
-        # print('old', old_clss)
 
         # Distillation loss on the logits of the old classes
         return self.cross_entropy(out_new[:, old_clss], out_old[:, old_clss])
 
     def bias_correction_step(
-        self, strategy: SupervisedTemplate, persistent_workers: bool = False
+        self,
+        strategy: SupervisedTemplate,
+        persistent_workers: bool = False,
+        num_workers: int = 0,
     ):
         # --- Prepare the models ---
         # Freeze the base model, only train the new bias layer
@@ -372,17 +385,17 @@
         )  # type: ignore
 
         # --- Prepare the dataloader for the validation set ---
-        list_subsets = []
+        list_subsets: List[AvalancheDataset] = []
         for _, class_buf in self.val_buffer.items():
             list_subsets.append(class_buf.buffer)
-        stage_set = _concat_taskaware_classification_datasets(list_subsets)
+        stage_set = concat_datasets(list_subsets)
 
         stage_loader = DataLoader(
             stage_set,
             batch_size=strategy.train_mb_size,
             shuffle=True,
-            num_workers=self.num_workers,
-            persistent_workers=persistent_workers if self.num_workers > 0 else False,
+            num_workers=num_workers,
+            persistent_workers=persistent_workers,
         )
 
         # Loop epochs
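For reference, a minimal usage sketch of the new default (not taken from the patch; `model`, `optimizer`, and `benchmark` are assumed to be defined elsewhere). With `num_workers="as_strategy"`, the stage 2 bias-correction loader reuses whatever `num_workers`/`persistent_workers` values `strategy.train(...)` forwards to the plugin callbacks:

```python
from torch.nn import CrossEntropyLoss

from avalanche.training.plugins import BiCPlugin
from avalanche.training.supervised import Naive

# num_workers defaults to "as_strategy": stage 2 reuses the strategy's workers.
bic = BiCPlugin(mem_size=2000)
strategy = Naive(model, optimizer, CrossEntropyLoss(), train_mb_size=32, plugins=[bic])

for experience in benchmark.train_stream:
    # num_workers/persistent_workers reach after_training_exp via **kwargs,
    # so the bias-correction DataLoader also runs with 4 workers here.
    strategy.train(experience, num_workers=4, persistent_workers=True)
```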