BiC: remove use of internal utility. Generalize num_workers argument.
lrzpellegrini committed Jan 24, 2024
1 parent 3ced031 commit b93752d
Showing 1 changed file with 29 additions and 16 deletions.
avalanche/training/plugins/bic.py: 45 changes (29 additions & 16 deletions)
@@ -2,10 +2,12 @@
 from typing import (
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
     SupportsInt,
+    Union,
 )
 
 from copy import deepcopy
@@ -15,11 +17,9 @@
 from torch.utils.data import DataLoader
 from torch.optim.lr_scheduler import MultiStepLR
 
-from avalanche.benchmarks.utils import (
-    _concat_taskaware_classification_datasets,
-)
 from avalanche.benchmarks.utils.data import AvalancheDataset
 from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
+from avalanche.benchmarks.utils.utils import concat_datasets
 from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
 from avalanche.training.storage_policy import (
     ExemplarsBuffer,
@@ -57,7 +57,7 @@ def __init__(
         stage_2_epochs: int = 200,
         lamb: float = -1,
         lr: float = 0.1,
-        num_workers: int = 4,
+        num_workers: Union[int, Literal["as_strategy"]] = "as_strategy",
         verbose: bool = False,
     ):
         """
@@ -84,7 +84,9 @@
             loss and the classification loss.
         :param lr: hyperparameter used as a learning rate for
             the second phase of training.
-        :param num_workers: number of workers used during stage 2 data loading
+        :param num_workers: number of workers used during stage 2 data loading.
+            Defaults to "as_strategy", which means that the number of workers
+            will be the same as the one used by the strategy.
         :param verbose: if True, prints additional info regarding the stage 2 stage
         """

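For context, here is how the reworked argument is meant to be used. A minimal sketch, assuming BiCPlugin is exported from avalanche.training.plugins and using the mem_size argument that appears later in this file's __init__:

from avalanche.training.plugins import BiCPlugin

# New default: inherit the training strategy's num_workers during the
# stage 2 (bias correction) data loading.
bic = BiCPlugin(mem_size=2000, num_workers="as_strategy")

# Explicit worker count, pinned regardless of the strategy's setting
# (the pre-commit behavior, where the default was 4):
bic_pinned = BiCPlugin(mem_size=2000, num_workers=4)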
@@ -110,7 +112,7 @@ def __init__(
         self.lamb = lamb
         self.mem_size = mem_size
         self.lr = lr
-        self.num_workers = num_workers
+        self.num_workers: Union[int, Literal["as_strategy"]] = num_workers
 
         self.seen_classes: Set[int] = set()
         self.class_to_tasks: Dict[int, int] = {}
@@ -168,9 +170,7 @@ def before_train_dataset_adaptation(self, strategy: "SupervisedTemplate", **kwargs):
         for class_id, class_buf in self.val_buffer.items():
             class_buf.resize(strategy, class_to_len[class_id])
 
-        strategy.experience.dataset = _concat_taskaware_classification_datasets(
-            train_data
-        )
+        strategy.experience.dataset = concat_datasets(train_data)
 
     def before_training_exp(
         self,
@@ -266,8 +266,19 @@ def after_training_exp(self, strategy, **kwargs):
         self.storage_policy.update(strategy, **kwargs)
 
         if task_id > 0:
+            num_workers = (
+                int(kwargs.get("num_workers", 0))
+                if self.num_workers == "as_strategy"
+                else self.num_workers
+            )
+            persistent_workers = (
+                False if num_workers == 0 else kwargs.get("persistent_workers", False)
+            )
+
             self.bias_correction_step(
-                strategy, persistent_workers=kwargs.get("persistent_workers", False)
+                strategy,
+                persistent_workers=persistent_workers,
+                num_workers=num_workers,
             )
 
     def cross_entropy(self, new_outputs, old_outputs):
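The persistent_workers guard added above matters because PyTorch's DataLoader raises a ValueError when persistent_workers=True is combined with num_workers=0. A standalone sketch of the same resolution logic (the function name and signature are illustrative, not part of the library):

from typing import Literal, Union

def resolve_workers(
    setting: Union[int, Literal["as_strategy"]],
    strategy_kwargs: dict,
) -> "tuple[int, bool]":
    # "as_strategy" falls back to whatever num_workers the training loop
    # received; an explicit int always wins.
    num_workers = (
        int(strategy_kwargs.get("num_workers", 0))
        if setting == "as_strategy"
        else setting
    )
    # Only honor persistent_workers when real worker processes exist,
    # since DataLoader rejects the flag when num_workers == 0.
    persistent = (
        False if num_workers == 0
        else strategy_kwargs.get("persistent_workers", False)
    )
    return num_workers, persistent

print(resolve_workers("as_strategy", {"num_workers": 8, "persistent_workers": True}))  # (8, True)
print(resolve_workers("as_strategy", {}))                                              # (0, False)
print(resolve_workers(4, {"persistent_workers": True}))                                # (4, True)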
@@ -334,13 +345,15 @@ def make_distillation_loss(self, strategy):
 
         # Union of initial_classes and previous_classes: needed to select the logits of all the old classes
         old_clss: List[int] = sorted(set(initial_classes) | set(previous_classes))
-        # print('old', old_clss)
 
         # Distillation loss on the logits of the old classes
         return self.cross_entropy(out_new[:, old_clss], out_old[:, old_clss])
 
     def bias_correction_step(
-        self, strategy: SupervisedTemplate, persistent_workers: bool = False
+        self,
+        strategy: SupervisedTemplate,
+        persistent_workers: bool = False,
+        num_workers: int = 0,
     ):
         # --- Prepare the models ---
         # Freeze the base model, only train the new bias layer
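In the BiC method (Wu et al., 2019), the distillation term selected above is a temperature-softened cross-entropy between the frozen old model's logits and the current model's logits on the old classes. A sketch of that computation, assuming this plugin's cross_entropy follows the paper (T = 2 there); it is not necessarily identical to the repository's implementation:

import torch
import torch.nn.functional as F

def distillation_ce(new_logits: torch.Tensor,
                    old_logits: torch.Tensor,
                    T: float = 2.0) -> torch.Tensor:
    # Teacher distribution from the frozen old model, softened by T.
    p_old = F.softmax(old_logits / T, dim=1)
    # Student log-distribution from the current model.
    log_p_new = F.log_softmax(new_logits / T, dim=1)
    # Cross-entropy of student against teacher, averaged over the batch.
    return -(p_old * log_p_new).sum(dim=1).mean()

As in the diff, both logit tensors would first be restricted to the old-class columns (out_new[:, old_clss], out_old[:, old_clss]) before this is applied.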
@@ -372,17 +385,17 @@
         ) # type: ignore
 
         # --- Prepare the dataloader for the validation set ---
-        list_subsets = []
+        list_subsets: List[AvalancheDataset] = []
         for _, class_buf in self.val_buffer.items():
             list_subsets.append(class_buf.buffer)
 
-        stage_set = _concat_taskaware_classification_datasets(list_subsets)
+        stage_set = concat_datasets(list_subsets)
         stage_loader = DataLoader(
             stage_set,
             batch_size=strategy.train_mb_size,
             shuffle=True,
-            num_workers=self.num_workers,
-            persistent_workers=persistent_workers if self.num_workers > 0 else False,
+            num_workers=num_workers,
+            persistent_workers=persistent_workers,
         )
 
         # Loop epochs