diff --git a/avalanche/_annotations.py b/avalanche/_annotations.py index 1820ae811..c23ac7d54 100644 --- a/avalanche/_annotations.py +++ b/avalanche/_annotations.py @@ -81,13 +81,12 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.simplefilter("always", DeprecationWarning) + warnings.simplefilter("once", DeprecationWarning) warnings.warn( msg.format(name=func.__name__, version=version, reason=reason), category=DeprecationWarning, stacklevel=2, ) - warnings.simplefilter("default", DeprecationWarning) return func(*args, **kwargs) return wrapper diff --git a/avalanche/benchmarks/scenarios/online.py b/avalanche/benchmarks/scenarios/online.py index b88c427d4..1e51dfadc 100644 --- a/avalanche/benchmarks/scenarios/online.py +++ b/avalanche/benchmarks/scenarios/online.py @@ -167,6 +167,7 @@ def __init__( shuffle: bool = True, drop_last: bool = False, access_task_boundaries: bool = False, + seed: int = None, ) -> None: """Returns a lazy stream generated by splitting an experience into smaller ones. @@ -181,7 +182,8 @@ def __init__( :param experience_size: The experience size (number of instances). :param shuffle: If True, instances will be shuffled before splitting. :param drop_last: If True, the last mini-experience will be dropped if - not of size `experience_size` + not of size `experience_size`. + :param seed: random seed for shuffling the data if `shuffle == True`. :return: The list of datasets that will be used to create the mini-experiences. """ @@ -190,10 +192,12 @@ def __init__( self.shuffle = shuffle self.drop_last = drop_last self.access_task_boundaries = access_task_boundaries + self.seed = seed # we need to fix the seed because repeated calls to the generator # must return the same order every time. - self.seed = random.randint(0, 2**32 - 1) + if seed is None: + self.seed = random.randint(0, 2**32 - 1) def __iter__(self) -> Generator[OnlineCLExperience, None, None]: exp_dataset = self.experience.dataset @@ -250,6 +254,7 @@ def _default_online_split( access_task_boundaries: bool, exp: DatasetExperience, size: int, + seed: int, ): return FixedSizeExperienceSplitter( experience=exp, @@ -257,6 +262,7 @@ shuffle=shuffle, drop_last=drop_last, access_task_boundaries=access_task_boundaries, + seed=seed, ) @@ -272,6 +278,7 @@ def split_online_stream( ] ] = None, access_task_boundaries: bool = False, + seed: int = None, ) -> CLStream[DatasetExperience[TCLDataset]]: """Split a stream of large batches to create an online stream of small mini-batches. @@ -300,6 +307,7 @@ A good starting point for understanding the mechanism is to look at the implementation of the standard splitting function :func:`fixed_size_experience_split`. + :param seed: random seed used for shuffling by the default splitter. :return: A lazy online stream with experiences of size `experience_size`.
""" @@ -308,7 +316,7 @@ def split_online_stream( # However, MyPy does not understand what a partial is -_- def default_online_split_wrapper(e, e_sz): return _default_online_split( - shuffle, drop_last, access_task_boundaries, e, e_sz + shuffle, drop_last, access_task_boundaries, e, e_sz, seed=seed ) split_strategy = default_online_split_wrapper diff --git a/avalanche/benchmarks/scenarios/validation_scenario.py b/avalanche/benchmarks/scenarios/validation_scenario.py index 74d3bb2ad..d3ad9e4ce 100644 --- a/avalanche/benchmarks/scenarios/validation_scenario.py +++ b/avalanche/benchmarks/scenarios/validation_scenario.py @@ -96,6 +96,7 @@ def random_validation_split_strategy_wrapper(data): # don't drop classes-timeline for compatibility with old API e0 = next(iter(train_stream)) + if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"): train_stream = with_classes_timeline(train_stream) valid_stream = with_classes_timeline(valid_stream) diff --git a/avalanche/core.py b/avalanche/core.py index 0d2364359..e1ed5368b 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -1,13 +1,160 @@ +""" +This module contains Protocols for some of the main components of Avalanche, +such as strategy plugins and the agent state. + +Most of these protocols are checked dynamically at runtime, so it is often not +necessary to inherit explicit from them or implement all the methods. +""" + from abc import ABC -from typing import Any, TypeVar, Generic +from typing import Any, TypeVar, Generic, Protocol, runtime_checkable from typing import TYPE_CHECKING +from avalanche.benchmarks import CLExperience + if TYPE_CHECKING: from avalanche.training.templates.base import BaseTemplate Template = TypeVar("Template", bound="BaseTemplate") +class Agent: + """Avalanche Continual Learning Agent. + + The agent stores the state needed by continual learning training methods, + such as optimizers, models, regularization losses. + You can add any objects as attributes dynamically: + + .. code-block:: + + agent = Agent() + agent.replay = ReservoirSamplingBuffer(max_size=200) + agent.loss = MaskedCrossEntropy() + agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) + agent.model = my_model + agent.opt = SGD(agent.model.parameters(), lr=0.001) + agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) + + Many CL objects will need to perform some operation before or + after training on each experience. This is supported via the `Adaptable` + Protocol, which requires the `pre_adapt` and `post_adapt` methods. + To call the pre/post adaptation you can implement your training loop + like in the following example: + + .. code-block:: + + def train(agent, exp): + agent.pre_adapt(exp) + # do training here + agent.post_adapt(exp) + + Objects that implement the `Adaptable` Protocol will be called by the Agent. + + You can also add additional functionality to the adaptation phases with + hooks. For example: + + .. code-block:: + agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()))) + # we update the lr scheduler after each experience (not every epoch!) + agent.add_post_hooks(lambda a, e: a.scheduler.step()) + + + """ + + def __init__(self, verbose=False): + """Init. + + :param verbose: If True, print every time an adaptable object or hook + is called during the adaptation. Useful for debugging. 
+ """ + self._updatable_objects = [] + self.verbose = verbose + self._pre_hooks = [] + self._post_hooks = [] + + def __setattr__(self, name, value): + super().__setattr__(name, value) + if hasattr(value, "pre_adapt") or hasattr(value, "post_adapt"): + self._updatable_objects.append(value) + if self.verbose: + print("Added updatable object ", value) + + def pre_adapt(self, exp): + """Pre-adaptation. + + Remember to call this before training on a new experience. + + :param exp: current experience + """ + for uo in self._updatable_objects: + if hasattr(uo, "pre_adapt"): + uo.pre_adapt(self, exp) + if self.verbose: + print("pre_adapt ", uo) + for foo in self._pre_hooks: + if self.verbose: + print("pre_adapt hook ", foo) + foo(self, exp) + + def post_adapt(self, exp): + """Post-adaptation. + + Remember to call this after training on a new experience. + + :param exp: current experience + """ + for uo in self._updatable_objects: + if hasattr(uo, "post_adapt"): + uo.post_adapt(self, exp) + if self.verbose: + print("post_adapt ", uo) + for foo in self._post_hooks: + if self.verbose: + print("post_adapt hook ", foo) + foo(self, exp) + + def add_pre_hooks(self, foo): + """Add a pre-adaptation hooks + + Hooks take two arguments: ``. + + :param foo: the hook function + """ + self._pre_hooks.append(foo) + + def add_post_hooks(self, foo): + """Add a post-adaptation hooks + + Hooks take two arguments: ``. + + :param foo: the hook function + """ + self._post_hooks.append(foo) + + +class Adaptable(Protocol): + """Adaptable objects Protocol. + + These class documents the Adaptable objects API but it is not necessary + for an object to inherit from it since the `Agent` will search for the methods + dynamically. + + Adaptable objects are objects that require to run their `pre_adapt` and + `post_adapt` methods before (and after, respectively) training on each + experience. + + Adaptable objects can implement only the method that they need since the + `Agent` will look for the methods dynamically and call it only if it is + implemented. + """ + + def pre_adapt(self, agent: Agent, exp: CLExperience): + pass + + def post_adapt(self, agent: Agent, exp: CLExperience): + pass + + class BasePlugin(Generic[Template], ABC): """ABC for BaseTemplate plugins. diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py new file mode 100644 index 000000000..9809f8e59 --- /dev/null +++ b/avalanche/evaluation/collector.py @@ -0,0 +1,187 @@ +""" +Utilities for metrics collection outside of Avalanche training Templates. +""" + +import json + +import numpy as np + +from avalanche._annotations import experimental + + +@experimental() +class MetricCollector: + """A simple metric collector object. + + Functionlity includes the ability to store metrics over time and compute + aggregated values of them. Serialization is supported via json files. + + You can pass a stream to the `update` method to compute separate metrics + for each stream. Metric names will become `{stream.name}/{metric_name}`. + When you recover the stream using the `get` method, you need to pass the + same stream. + + Example usage: + + .. 
code-block:: python + + mc = MetricCollector() + for exp in train_stream: + res = {"Accuracy": [0.1]} # one value for each experience in the stream + mc.update(res, stream=test_stream) + acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) + + """ + + def __init__(self): + """Init.""" + self.metrics_res = {} + + self._stream_len = {} # stream-name -> stream length + self._coeffs = {} # stream-name -> list_of_exp_length + + def _init_stream(self, stream): + # we compute the stream length and number of samples in each experience + # only the first time we encounter the stream because + # iterating over the stream may be expensive + self._stream_len[stream.name] = 0 + coeffs = [] + for exp in stream: + self._stream_len[stream.name] += 1 + coeffs.append(len(exp.dataset)) + coeffs = np.array(coeffs) + self._coeffs[stream.name] = coeffs / coeffs.sum() + + def update(self, res, *, stream=None): + """Update the metrics. + + :param res: a dictionary of new metric values, where each item maps a + metric name to a list with one value per experience in the stream. + :param stream: optional stream. If a stream is given the full metric + name becomes `f'{stream.name}/{metric_name}'`. + """ + for k, v in res.items(): + # optional safety check on metric shape + # metrics are expected to have one value for each experience in the stream + if stream is not None: + if stream.name not in self._stream_len: + self._init_stream(stream) + if len(v) != self._stream_len[stream.name]: + raise ValueError( + f"Length does not correspond to stream. " + f"Found {len(v)}, expected {self._stream_len[stream.name]}" + ) + + # update metrics dictionary + if stream is not None: + k = f"{stream.name}/{k}" + if k in self.metrics_res: + self.metrics_res[k].append(v) + else: + self.metrics_res[k] = [v] + + def get( + self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None + ): + """Returns a metric value given its name and aggregation method. + + :param name: name of the metric. + :param time_reduce: Aggregation over the time dimension. One of {None, 'last', 'mean'}, where: + - None (default) does not use any aggregation + - 'last' returns the last timestep + - 'mean' averages over time + :param exp_reduce: Aggregation over the experience dimension. One of {None, 'sample_mean', 'experience_mean', 'weighted_sum'}, where: + - None (default) does not use any aggregation + - 'sample_mean' is an average weighted by the number of samples in each experience + - 'experience_mean' is an average over experiences. + - 'weighted_sum' is a weighted sum of the experiences using the `weights` argument. + :param stream: stream that was used to compute the metric. This is + needed to build the full metric name if `update` was called with a + stream, and if `exp_reduce == 'sample_mean'` to get the number + of samples from each experience. + :param weights: weights for each experience when `exp_reduce == 'weighted_sum'`. + :return: aggregated metric value. + """ + assert time_reduce in {None, "last", "mean"} + assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"} + if exp_reduce == "weighted_sum": + assert ( + weights is not None + ), "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`." + else: + assert ( + weights is None + ), "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`" + + if stream is not None: + name = f"{stream.name}/{name}" + if name not in self.metrics_res: + raise KeyError( + f"{name} metric was not found. Maybe you forgot the `stream` argument?"
+ ) + + mvals = np.array(self.metrics_res[name]) + if exp_reduce is None: + pass # nothing to do here + elif exp_reduce == "sample_mean": + if stream is None: + raise ValueError( + "If you want to use `exp_reduce == sample_mean` you need to provide " + "the `stream` argument to the `update` and `get` methods." + ) + if stream.name not in self._coeffs: + self._init_stream(stream) + mvals = mvals * self._coeffs[stream.name][None, :] # weight by num.samples + mvals = mvals.sum(axis=1) # weighted avg across exp. + elif exp_reduce == "experience_mean": + mvals = mvals.mean(axis=1) # avg across exp. + elif exp_reduce == "weighted_sum": + weights = np.array(weights)[None, :] + mvals = (mvals * weights).sum(axis=1) + else: + raise ValueError("BUG. It should never get here.") + + if time_reduce is None: + pass # nothing to do here + elif time_reduce == "last": + mvals = mvals[-1] # last timestep + elif time_reduce == "mean": + mvals = mvals.mean(axis=0) # avg. over time + else: + raise ValueError("BUG. It should never get here.") + + return mvals + + def get_dict(self): + """Returns metrics dictionary. + + :return: metrics dictionary + """ + # TODO: test + return self.metrics_res + + def load_dict(self, d): + """Loads a new metrics dictionary. + + :param d: metrics dictionary + """ + # TODO: test + self.metrics_res = d + + def load_json(self, fname): + """Loads a metrics dictionary from a json file. + + :param fname: file name + """ + # TODO: test + with open(fname, "r") as f: + self.metrics_res = json.load(f) + + def to_json(self, fname): + """Stores the metrics dictionary as a json file. + + :param fname: file name + """ + # TODO: test + with open(fname, "w") as f: + json.dump(obj=self.metrics_res, fp=f) diff --git a/avalanche/evaluation/functional.py b/avalanche/evaluation/functional.py new file mode 100644 index 000000000..862e8a240 --- /dev/null +++ b/avalanche/evaluation/functional.py @@ -0,0 +1,36 @@ +import numpy as np + + +def forgetting(accuracy_matrix, boundary_indices=None): + """Forgetting. + + Given an experience `k` learned at time `boundary_indices[k]`, forgetting + at time `t` is `accuracy_matrix[boundary_indices[k], k] - accuracy_matrix[t, k]`. + Forgetting is set to zero before learning on the experience. + + :param accuracy_matrix: 2D accuracy matrix with shape (n_timesteps, n_experiences). + :param boundary_indices: time index for each experience, corresponding to + the time after the experience was learned. Optional if + `accuracy_matrix` is a square matrix, which is the most common case. + In this setting, `boundary_indices[k] = k`.
+ :return: 2D forgetting matrix with the same shape as `accuracy_matrix`. + """ + accuracy_matrix = np.array(accuracy_matrix) + forgetting_matrix = np.zeros_like(accuracy_matrix) + + if boundary_indices is None: + assert accuracy_matrix.shape[0] == accuracy_matrix.shape[1] + boundary_indices = list(range(accuracy_matrix.shape[0])) + + for k in range(accuracy_matrix.shape[1]): + t_task = boundary_indices[k] + acc_first = accuracy_matrix[t_task][k] + + # before learning exp k, forgetting is zero + forgetting_matrix[: t_task + 1, k] = 0 + # then, it's acc_first - acc[t, k] + forgetting_matrix[t_task + 1 :, k] = ( + acc_first - accuracy_matrix[t_task + 1 :, k] + ) + + return forgetting_matrix diff --git a/avalanche/evaluation/plot_utils.py b/avalanche/evaluation/plot_utils.py index 44a05ee6d..0c7b5f591 100644 --- a/avalanche/evaluation/plot_utils.py +++ b/avalanche/evaluation/plot_utils.py @@ -8,7 +8,8 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ -import matplotlib.pyplot as plt +import numpy as np +from matplotlib import pyplot as plt def learning_curves_plot(all_metrics: dict): @@ -28,3 +29,38 @@ ax.set_xlabel("Iterations") ax.set_ylabel("Experience Accuracy") return fig + + +def plot_metric_matrix(metric_matrix, title, *, ax=None, text_values=True): + """Plot a matrix of metrics (e.g. forgetting over time). + + :param metric_matrix: 2D metric matrix with shape (time, experiences). + :param title: plot title + :param ax: axes to use for the figure + :param text_values: (bool) whether to add the value as text in each cell. + + :return: a matplotlib.Figure + """ + if ax is None: + fig, ax = plt.subplots() + else: + fig = ax.figure + + metric_matrix = np.array(metric_matrix).T + ax.matshow(metric_matrix) + ax.set_ylabel("Experience") + ax.set_xlabel("Time") + ax.set_title(title) + + if text_values: + for i in range(len(metric_matrix)): + for j in range(len(metric_matrix[0])): + ax.text( + j, + i, + f"{metric_matrix[i][j]:.3f}", + ha="center", + va="center", + color="w", + ) + return fig diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index 436c75bee..9a43cda54 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -60,7 +60,7 @@ class DynamicModule(Module): expanded to allow architectural modifications (multi-head classifiers, progressive networks, ...). - Compared to pytoch Modules, they provide an additional method, + Compared to pytorch Modules, they provide an additional method, `model_adaptation`, which adapts the model given the current experience. """ @@ -73,7 +73,7 @@ def __init__(self, auto_adapt=True): super().__init__() self._auto_adapt = auto_adapt - def recursive_adaptation(self, experience): + def pre_adapt(self, agent, experience): """ Calls self.adaptation recursively across the hierarchy of pytorch module children diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index 33d6eb72d..10e487dab 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -18,7 +18,9 @@ import numpy as np -from avalanche._annotations import deprecated +from avalanche._annotations import deprecated, experimental +from avalanche.benchmarks import CLExperience +from avalanche.core import Adaptable, Agent colors = { "END": "\033[0m", @@ -31,6 +33,71 @@ colors[None] = colors["END"] +@experimental( + "New dynamic optimizers. The API may slightly change in the next versions."
+) +class DynamicOptimizer(Adaptable): + """Avalanche dynamic optimizer. + + In continual learning, many model architecture may change over time (e.g. + adding new units to the classifier). This is handled by the `DynamicModule`. + Changing the model's architecture requires updating the optimizer to add + the new parameters in the optimizer's state. + + This class provides a simple wrapper to handle the optimizer + update via the `Adaptable` Protocol. + + This class provides direct methods only for `zero_grad` and `step` to support + basic training functionality. Other methods of the base optimizers must be + called by using the base optimizer directly (e.g. `self.optim.add_param_group`). + + NOTE: the model must be adapted *before* calling this method. + To ensure this, ensure that the optimizer is added to the agent state + after the model: + + .. code-block:: + + agent.model = model + agent.optimizer = optimizer + + # ... more init code + + # pre_adapt will call the pre_adapt methods in order, + # first model.pre_adapt, then optimizer.pre_adapt + agent.pre_adapt(experience) + """ + + def __init__(self, optim, model, reset_state=False, verbose=False): + self.optim = optim + self.reset_state = reset_state + self.verbose = verbose + + # initialize param-id + self._optimized_param_id = update_optimizer( + self.optim, + dict(model.named_parameters()), + None, + reset_state=True, + verbose=self.verbose, + ) + + def zero_grad(self): + self.optim.zero_grad() + + def step(self): + self.optim.step() + + def pre_adapt(self, agent: Agent, exp: CLExperience): + """Adapt the optimizer before training on the current experience.""" + self._optimized_param_id = update_optimizer( + self.optim, + dict(agent.model.named_parameters()), + self._optimized_param_id, + reset_state=self.reset_state, + verbose=self.verbose, + ) + + def _map_optimized_params(optimizer, parameters, old_params=None): """ Establishes a mapping between a list of named parameters and the parameters @@ -275,16 +342,22 @@ def update_optimizer( Newly added parameters are added by default to parameter group 0 - :param new_params: Dict (name, param) of new parameters + WARNING: the first call to `update_optimizer` must be done before + calling the model's adaptation. + + :param optimizer: the Optimizer object. + :param new_params: Dict (name, param) of new parameters. :param optimized_params: Dict (name, param) of - currently optimized parameters + currently optimized parameters. In most use cases, it will be `None in + the first call and the return value of the last `update_optimizer` call + for the subsequent calls. :param reset_state: Whether to reset the optimizer's state (i.e momentum). Defaults to False. :param remove_params: Whether to remove parameters that were in the optimizer but are not found in new parameters. For safety reasons, defaults to False. :param verbose: If True, prints information about inferred - parameter groups for new params + parameter groups for new params. 
:return: Dict (name, param) of optimized parameters """ @@ -327,7 +400,6 @@ def update_optimizer( new_p = new_params[key] group = param_structure[key].single_group optimizer.param_groups[group]["params"].append(new_p) - optimized_params[key] = new_p optimizer.state[new_p] = {} if reset_state: diff --git a/avalanche/training/losses.py b/avalanche/training/losses.py index 014cb4129..460b01aa2 100644 --- a/avalanche/training/losses.py +++ b/avalanche/training/losses.py @@ -8,6 +8,7 @@ from avalanche.training.plugins import SupervisedPlugin from avalanche.training.regularization import cross_entropy_with_oh_targets +from avalanche._annotations import deprecated class ICaRLLossPlugin(SupervisedPlugin): @@ -214,12 +215,20 @@ def current_mask(self, logit_shape): elif self.mask == "all": return list(range(int(logit_shape))) - def adaptation(self, new_classes): + def _adaptation(self, new_classes): self.old_classes = self.old_classes.union(self.current_classes) self.current_classes = set(new_classes) + def pre_adapt(self, agent, exp): + self._adaptation(exp.classes_in_this_experience) + + @deprecated(0.7, "Please switch to the `pre_adapt` or `_adaptation` methods.") + def adaptation(self, new_classes): + self._adaptation(new_classes) + + @deprecated(0.7, "Please switch to the `pre_adapt` method.") def before_training_exp(self, strategy, **kwargs): - self.adaptation(strategy.experience.classes_in_this_experience) + self._adaptation(strategy.experience.classes_in_this_experience) __all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"] diff --git a/avalanche/training/plugins/lwf.py b/avalanche/training/plugins/lwf.py index 28ae33ecd..4ed23d141 100644 --- a/avalanche/training/plugins/lwf.py +++ b/avalanche/training/plugins/lwf.py @@ -31,4 +31,4 @@ def after_training_exp(self, strategy, **kwargs): Save a copy of the model after each experience and update self.prev_classes to include the newly learned classes. """ - self.lwf.update(strategy.experience, strategy.model) + self.lwf.post_adapt(strategy, strategy.experience) diff --git a/avalanche/training/regularization.py b/avalanche/training/regularization.py index 277dc1ee3..dc4ab310f 100644 --- a/avalanche/training/regularization.py +++ b/avalanche/training/regularization.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from avalanche.models import MultiTaskModule, avalanche_forward +from avalanche._annotations import deprecated def stable_softmax(x): @@ -39,9 +40,16 @@ class RegularizationMethod: of an experience. """ + @deprecated(0.7, "please switch to the pre_adapt and post_adapt methods.") def update(self, *args, **kwargs): raise NotImplementedError() + def pre_adapt(self, agent, exp): + pass # implementation may be empty if adapt is not needed + + def post_adapt(self, agent, exp): + pass # implementation may be empty if adapt is not needed + def __call__(self, *args, **kwargs): raise NotImplementedError() @@ -133,6 +141,29 @@ def __call__(self, mb_x, mb_pred, model): ) return alpha * self._lwf_penalty(mb_pred, mb_x, model) + def post_adapt(self, agent, exp): + """Save a copy of the model after each experience and + update self.prev_classes to include the newly learned classes.
+ + :param agent: agent state + :param exp: current experience + """ + self.expcount += 1 + self.prev_model = copy.deepcopy(agent.model) + task_ids = exp.dataset.targets_task_labels.uniques + + for task_id in task_ids: + task_data = exp.dataset.task_set[task_id] + pc = set(task_data.targets.uniques) + + if task_id not in self.prev_classes_by_task: + self.prev_classes_by_task[task_id] = pc + else: + self.prev_classes_by_task[task_id] = self.prev_classes_by_task[ + task_id + ].union(pc) + + @deprecated(0.7, "switch to the post_adapt method") def update(self, experience, model): """Save a copy of the model after each experience and update self.prev_classes to include the newly learned classes. @@ -140,7 +171,6 @@ def update(self, experience, model): :param experience: current experience :param model: current model """ - self.expcount += 1 self.prev_model = copy.deepcopy(model) task_ids = experience.dataset.targets_task_labels.uniques diff --git a/avalanche/training/storage_policy.py b/avalanche/training/storage_policy.py index c0df18c59..ca450a9dd 100644 --- a/avalanche/training/storage_policy.py +++ b/avalanche/training/storage_policy.py @@ -24,6 +24,7 @@ ) from avalanche.models import FeatureExtractorBackbone from ..benchmarks.utils.utils import concat_datasets +from avalanche._annotations import deprecated if TYPE_CHECKING: from .templates import SupervisedTemplate, BaseSGDTemplate @@ -54,7 +55,7 @@ def buffer(self) -> AvalancheDataset: def buffer(self, new_buffer: AvalancheDataset): self._buffer = new_buffer - @abstractmethod + @deprecated(0.7, "switch to pre_adapt and post_adapt") def update(self, strategy: "SupervisedTemplate", **kwargs): """Update `self.buffer` using the `strategy` state. :param strategy: :param kwargs: :return: """ - ... + # this should work until we deprecate self.update + self.post_adapt(strategy, strategy.experience) + + def post_adapt(self, agent_state, exp): + """Update `self.buffer` using the agent state and current experience. + + :param agent_state: agent state + :param exp: current experience + """ + pass @abstractmethod def resize(self, strategy: "SupervisedTemplate", new_size: int): @@ -92,11 +103,9 @@ def __init__(self, max_size: int): # INVARIANT: _buffer_weights is always sorted. self._buffer_weights = torch.zeros(0) - def update(self, strategy: "SupervisedTemplate", **kwargs): + def post_adapt(self, agent, exp): """Update buffer.""" - experience = strategy.experience - assert experience is not None - self.update_from_dataset(experience.dataset) + self.update_from_dataset(exp.dataset) def update_from_dataset(self, new_data: AvalancheDataset): """Update the buffer using the given dataset. @@ -193,16 +202,6 @@ def buffer(self, new_buffer): "You should modify `self.buffer_groups` instead." ) - @abstractmethod - def update(self, strategy: "SupervisedTemplate", **kwargs): - """Update `self.buffer_groups` using the `strategy` state. - - :param strategy: - :param kwargs: - :return: - """ - ... - def resize(self, strategy, new_size): """Update the maximum size of the buffers.""" self.max_size = new_size @@ -229,19 +228,19 @@ def __init__(self, max_size: int, adaptive_size: bool = True, num_experiences=No of experiences to divide capacity over.
""" super().__init__(max_size, adaptive_size, num_experiences) + self._num_exps = 0 - def update(self, strategy: "SupervisedTemplate", **kwargs): - assert strategy.experience is not None - new_data = strategy.experience.dataset - num_exps = strategy.clock.train_exp_counter + 1 - lens = self.get_group_lengths(num_exps) + def post_adapt(self, agent, exp): + self._num_exps += 1 + new_data = exp.dataset + lens = self.get_group_lengths(self._num_exps) new_buffer = ReservoirSamplingBuffer(lens[-1]) new_buffer.update_from_dataset(new_data) - self.buffer_groups[num_exps - 1] = new_buffer + self.buffer_groups[self._num_exps - 1] = new_buffer for ll, b in zip(lens, self.buffer_groups.values()): - b.resize(strategy, ll) + b.resize(agent, ll) class ClassBalancedBuffer(BalancedExemplarsBuffer[ReservoirSamplingBuffer]): @@ -279,10 +278,9 @@ def __init__( self.total_num_classes = total_num_classes self.seen_classes: Set[int] = set() - def update(self, strategy: "BaseSGDTemplate", **kwargs): + def post_adapt(self, agent, exp): """Update buffer.""" - assert strategy.experience is not None - self.update_from_dataset(strategy.experience.dataset, strategy) + self.update_from_dataset(exp.dataset, agent) def update_from_dataset( self, new_data: AvalancheDataset, strategy: Optional["BaseSGDTemplate"] = None @@ -361,10 +359,9 @@ def __init__( self.seen_groups: Set[int] = set() self._curr_strategy = None - def update(self, strategy: "SupervisedTemplate", **kwargs): - assert strategy.experience is not None - new_data: AvalancheDataset = strategy.experience.dataset - new_groups = self._make_groups(strategy, new_data) + def post_adapt(self, agent, exp): + new_data: AvalancheDataset = exp.dataset + new_groups = self._make_groups(agent, new_data) self.seen_groups.update(new_groups.keys()) # associate lengths to classes @@ -378,16 +375,16 @@ def update(self, strategy: "SupervisedTemplate", **kwargs): ll = group_to_len[group_id] if group_id in self.buffer_groups: old_buffer_g = self.buffer_groups[group_id] - old_buffer_g.update_from_dataset(strategy, new_data_g) - old_buffer_g.resize(strategy, ll) + old_buffer_g.update_from_dataset(agent, new_data_g) + old_buffer_g.resize(agent, ll) else: new_buffer = _ParametricSingleBuffer(ll, self.selection_strategy) - new_buffer.update_from_dataset(strategy, new_data_g) + new_buffer.update_from_dataset(agent, new_data_g) self.buffer_groups[group_id] = new_buffer # resize buffers for group_id, class_buf in self.buffer_groups.items(): - self.buffer_groups[group_id].resize(strategy, group_to_len[group_id]) + self.buffer_groups[group_id].resize(agent, group_to_len[group_id]) def _make_groups( self, strategy, data: AvalancheDataset diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 8405c5a6f..20ba0465a 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -80,7 +80,6 @@ class SupervisedTemplate( PLUGIN_CLASS = SupervisedPlugin - # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, *, @@ -194,7 +193,6 @@ class SupervisedMetaLearningTemplate( PLUGIN_CLASS = SupervisedPlugin - # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, *, diff --git a/docs/gitbook/from-zero-to-hero-tutorial/02_models.md b/docs/gitbook/from-zero-to-hero-tutorial/02_models.md index 0c566f83d..ba6ca49c2 100644 --- 
a/docs/gitbook/from-zero-to-hero-tutorial/02_models.md +++ b/docs/gitbook/from-zero-to-hero-tutorial/02_models.md @@ -109,7 +109,6 @@ print(mt_model) ### Nested Dynamic Modules Whenever one or more dynamic modules are nested one inside the other, you must call the `pre_adapt` method, and if they are nested inside a normal pytorch module (non-dynamic), you can call the `avalanche_model_adaptation` function. Avalanche strategies will by default adapt the models before training on each experience by calling `avalanche_model_adaptation` - ```python benchmark = SplitMNIST(5, shuffle=False, class_ids_from_zero_in_each_exp=True, return_task_id=True) @@ -118,7 +117,7 @@ mt_model = as_multitask(model, 'classifier') print(mt_model) for exp in benchmark.train_stream: - mt_model.recursive_adaptation(exp) + mt_model.pre_adapt(None, exp) print(mt_model) ``` diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py new file mode 100644 index 000000000..5313561d6 --- /dev/null +++ b/examples/updatable_objects.py @@ -0,0 +1,154 @@ +import os + +import setGPU + +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +from avalanche.benchmarks import SplitMNIST, with_classes_timeline +from avalanche.benchmarks.scenarios import split_online_stream +from avalanche.benchmarks.utils.data_loader import ReplayDataLoader +from avalanche.core import Agent + +from avalanche.evaluation.collector import MetricCollector +from avalanche.evaluation.functional import forgetting +from avalanche.evaluation.metrics import Accuracy +from avalanche.evaluation.plot_utils import plot_metric_matrix +from avalanche.models import SimpleMLP, IncrementalClassifier +from avalanche.models.dynamic_modules import avalanche_model_adaptation +from avalanche.models.dynamic_optimizers import update_optimizer, DynamicOptimizer +from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting +from avalanche.training.losses import MaskedCrossEntropy + + +def train_experience(agent_state, exp, epochs=10): + agent_state.model.train() + # avalanche datasets have train/eval modes to switch augmentations + data = exp.dataset.train() + agent_state.pre_adapt(exp) # update objects and call pre_hooks + for ep in range(epochs): + if len(agent_state.replay.buffer) > 0: + # if the replay buffer is not empty we sample from + # current data and replay buffer in parallel + dl = ReplayDataLoader( + data, agent_state.replay.buffer, batch_size=32, shuffle=True + ) + else: + dl = DataLoader(data, batch_size=32, shuffle=True) + + for x, y, _ in dl: + x, y = x.cuda(), y.cuda() + agent_state.opt.zero_grad() + yp = agent_state.model(x) + l = agent_state.loss(yp, y) + + # you may have to change this if your regularization loss + # needs different arguments + l += agent_state.reg_loss(x, yp, agent_state.model) + + l.backward() + agent_state.opt.step() + agent_state.post_adapt(exp) # update objects and call post_hooks + + +@torch.no_grad() +def my_eval(model, stream, metrics): + """Evaluate `model` on `stream` computing `metrics`. + + Returns a dictionary {metric_name: list-of-results}.
+ """ + model.eval() + res = {uo.__class__.__name__: [] for uo in metrics} + for exp in stream: + [uo.reset() for uo in metrics] + dl = DataLoader(exp.dataset.eval(), batch_size=512, num_workers=8) + for x, y, _ in dl: + x, y = x.cuda(), y.cuda() + yp = model(x) + [uo.update(yp, y) for uo in metrics] + [res[uo.__class__.__name__].append(uo.result()) for uo in metrics] + return res + + +if __name__ == "__main__": + bm = SplitMNIST(n_experiences=5) + train_stream, test_stream = bm.train_stream, bm.test_stream + # we split the training stream into online experiences + ocl_stream = split_online_stream(train_stream, experience_size=256, seed=1234) + # we add class attributes to the experiences + ocl_stream = with_classes_timeline(ocl_stream) + + # agent state collects all the objects that are needed during training + # many of these objects will have some state that is updated during training. + # Put all the stateful training objects here. Calling the agent `pre_adapt` + # and `post_adapt` methods will adapt all the objects. + # Sometimes, it may be useful to also put non-stateful objects (e.g. losses) + # to easily switch between them while keeping the same training loop. + agent = Agent() + agent.replay = ReservoirSamplingBuffer(max_size=200) + agent.loss = MaskedCrossEntropy() + agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) + + # Avalanche models support dynamic addition and expansion of parameters + agent.model = SimpleMLP().cuda() + agent.model.classifier = IncrementalClassifier(in_features=512).cuda() + agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e)) + + # optimizer and scheduler + # we have update the optimizer before each experience. + # This is needed because the model's parameters may change if you are using + # a dynamic model. + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt) + agent.scheduler = ExponentialLR(opt, gamma=0.999) + # we use a hook to call the scheduler. + # we update the lr scheduler after each experience (not every epoch!) + agent.add_post_hooks(lambda a, e: a.scheduler.step()) + + print("Start training...") + metrics = [Accuracy()] + mc = MetricCollector() # we put all the metric values here + for exp in ocl_stream: + if exp.logging().is_first_subexp: # new drift + print("training on classes ", exp.classes_in_this_experience) + train_experience(agent, exp, epochs=2) + if exp.logging().is_last_subexp: + # after learning new classes, do the eval + # notice that even in online settings we use the non-online stream + # for evaluation because each experience in the original stream + # corresponds to a separate data distribution. 
+ res = my_eval(agent.model, train_stream, metrics) + mc.update(res, stream=train_stream) + print("TRAIN METRICS: ", res) + res = my_eval(agent.model, test_stream, metrics) + mc.update(res, stream=test_stream) + print("TEST METRICS: ", res) + print() + + os.makedirs("./logs", exist_ok=True) + torch.save(agent.model.state_dict(), "./logs/model.pth") + + print(mc.get_dict()) + mc.to_json("./logs/mc.json") # save metrics + + acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) + print("Accuracy:\n", mc.get("Accuracy", exp_reduce=None, stream=test_stream)) + print("AVG Accuracy: ", acc_timeline) + + acc_matrix = mc.get("Accuracy", exp_reduce=None, stream=test_stream) + fig = plot_metric_matrix(acc_matrix, title="Accuracy - Train") + fig.savefig("./logs/accuracy.png") + + fm = forgetting(acc_matrix) + print("Forgetting:\n", fm) + fig = plot_metric_matrix(fm, title="Forgetting - Train") + fig.savefig("./logs/forgetting.png") + + # BONUS: example of re-loading the model from disk + model = SimpleMLP() + model.classifier = IncrementalClassifier(in_features=512) + for exp in ocl_stream: # we need to adapt the model's architecture again + avalanche_model_adaptation(model, exp) + model.load_state_dict(torch.load("./logs/model.pth")) diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py new file mode 100644 index 000000000..78ee2ac34 --- /dev/null +++ b/tests/evaluation/test_collector.py @@ -0,0 +1,124 @@ +""" Metrics Tests""" + +import unittest +from types import SimpleNamespace + +import torch + +from avalanche.evaluation.collector import MetricCollector +from avalanche.evaluation.functional import forgetting +import numpy as np + + +class MetricCollectorTests(unittest.TestCase): + def test_time_reduce(self): + # we just need an object with __len__ + class ToyStream: + def __init__(self): + self.els = [ + SimpleNamespace(dataset=list(range(1))), + SimpleNamespace(dataset=list(range(2))), + ] + self.name = "toy" + + def __getitem__(self, item): + return self.els[item] + + def __len__(self): + return len(self.els) + + fake_stream = ToyStream() + # m = MockMetric([0, 1, 2, 3, 4, 5]) + mc = MetricCollector() + stream_of_mvals = [[1, 3], [5, 7], [11, 13]] + for vvv in stream_of_mvals: + mc.update({"toy/FakeMetric": vvv}) + + # time_reduce = None + v = mc.get("FakeMetric", time_reduce=None, exp_reduce=None, stream=fake_stream) + np.testing.assert_array_almost_equal(v, stream_of_mvals) + v = mc.get( + "FakeMetric", time_reduce=None, exp_reduce="sample_mean", stream=fake_stream + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3 * 2) / 3, (5 + 7 * 2) / 3, (11 + 13 * 2) / 3] + ) + v = mc.get( + "FakeMetric", + time_reduce=None, + exp_reduce="experience_mean", + stream=fake_stream, + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3) / 2, (5 + 7) / 2, (11 + 13) / 2] + ) + v = mc.get( + "FakeMetric", + time_reduce=None, + exp_reduce="weighted_sum", + weights=[1, 2], + stream=fake_stream, + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3 * 2), (5 + 7 * 2), (11 + 13 * 2)] + ) + + # time = "last" + v = mc.get( + "FakeMetric", time_reduce="last", exp_reduce=None, stream=fake_stream + ) + np.testing.assert_array_almost_equal(v, [11, 13]) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="sample_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, (11 + 13 * 2) / 3) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="experience_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, (11 + 13) / 2) + v 
= mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2], + ) + self.assertAlmostEqual(v, 11 + 13 * 2) + + # time_reduce = "mean" + v = mc.get( + "FakeMetric", time_reduce="mean", exp_reduce=None, stream=fake_stream + ) + np.testing.assert_array_almost_equal(v, [(1 + 5 + 11) / 3, (3 + 7 + 13) / 3]) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="sample_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3 * 2) / 3) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="experience_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3) / 2) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2], + ) + self.assertAlmostEqual(v, ((1 + 3 * 2) + (5 + 7 * 2) + (11 + 13 * 2)) / 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/evaluation/test_functional.py b/tests/evaluation/test_functional.py new file mode 100644 index 000000000..7c78a509f --- /dev/null +++ b/tests/evaluation/test_functional.py @@ -0,0 +1,27 @@ +""" Metrics Tests""" + +import unittest + +import torch + +from avalanche.evaluation.functional import forgetting +import numpy as np + + +class FunctionalMetricTests(unittest.TestCase): + def test_diag_forgetting(self): + vals = [[1.0, 0.0, 0.0], [0.5, 0.8, 1.0], [0.2, 1.0, 0.5]] + target = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.8, -0.2, 0.0]] + fm = forgetting(vals) + np.testing.assert_array_almost_equal(fm, target) + + def test_non_diag_forgetting(self): + """non-diagonal matrix requires explicit boundary indices""" + vals = [[1.0, 0.0, 0.0], [0.5, 0.8, 1.0], [0.5, 0.5, 1.0], [0.2, 1.0, 0.5]] + target = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.3, 0.0], [0.8, -0.2, 0.0]] + fm = forgetting(vals, boundary_indices=[0, 1, 3]) + np.testing.assert_array_almost_equal(fm, target) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/evaluation/test_plots.py b/tests/evaluation/test_plots.py new file mode 100644 index 000000000..2a628417a --- /dev/null +++ b/tests/evaluation/test_plots.py @@ -0,0 +1,23 @@ +""" Metrics Tests""" + +import unittest + +from avalanche.evaluation.plot_utils import plot_metric_matrix + + +class PlotMetricTests(unittest.TestCase): + def test_basic(self): + """just checking that the code runs without errors.""" + mvals = [ + [0.76370887, 0.43307481, 0.30893094, 0.25093633, 0.2075], + [0.74177468, 0.82286715, 0.55876494, 0.42234707, 0.341], + [0.57178465, 0.65048787, 0.7563081, 0.56941323, 0.4562], + [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], + [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], + [0.4995015, 0.58819114, 0.57320717, 0.62322097, 0.6925], + ] + plot_metric_matrix(mvals, title="Accuracy - Train") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/models/test_dynamic_optimizers.py b/tests/models/test_dynamic_optimizers.py new file mode 100644 index 000000000..c7f8b6f25 --- /dev/null +++ b/tests/models/test_dynamic_optimizers.py @@ -0,0 +1,130 @@ +import unittest + +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torch.utils.data import DataLoader + +from avalanche.core import Agent +from avalanche.models import SimpleMLP, as_multitask, IncrementalClassifier, MTSimpleMLP +from avalanche.models.dynamic_modules import avalanche_model_adaptation +from avalanche.models.dynamic_optimizers import DynamicOptimizer +from 
avalanche.training import MaskedCrossEntropy +from tests.unit_tests_utils import get_fast_benchmark + + +class TestDynamicOptimizers(unittest.TestCase): + def test_dynamic_optimizer(self): + bm = get_fast_benchmark(use_task_labels=True) + agent = Agent() + agent.loss = MaskedCrossEntropy() + agent.model = as_multitask(SimpleMLP(input_size=6), "classifier") + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt, agent.model, verbose=False) + + for exp in bm.train_stream: + agent.model.train() + data = exp.dataset.train() + agent.pre_adapt(exp) + for ep in range(1): + dl = DataLoader(data, batch_size=32, shuffle=True) + for x, y, t in dl: + agent.opt.zero_grad() + yp = agent.model(x, t) + l = agent.loss(yp, y) + l.backward() + agent.opt.step() + agent.post_adapt(exp) + + def init_scenario(self, multi_task=False): + if multi_task: + model = MTSimpleMLP(input_size=6, hidden_size=10) + else: + model = SimpleMLP(input_size=6, hidden_size=10) + model.classifier = IncrementalClassifier(10, 1) + criterion = CrossEntropyLoss() + benchmark = get_fast_benchmark(use_task_labels=multi_task) + return model, criterion, benchmark + + def _is_param_in_optimizer(self, param, optimizer): + for group in optimizer.param_groups: + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return True + return False + + def _is_param_in_optimizer_group(self, param, optimizer): + for group_idx, group in enumerate(optimizer.param_groups): + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return group_idx + return None + + def test_optimizer_groups_clf_til(self): + """ + Tests the automatic assignment of new + MultiHead parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=True) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + agent = Agent() + agent.model = model + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + agent.optimizer = DynamicOptimizer(optimizer, model=model, verbose=False) + + for experience in benchmark.train_stream: + avalanche_model_adaptation(model, experience) + agent.optimizer.pre_adapt(agent, experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, agent.optimizer.optim) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1 + ) + + def test_optimizer_groups_clf_cil(self): + """ + Tests the automatic assignment of new + IncrementalClassifier parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=False) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + agent = Agent() + agent.model = model + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + agent.optimizer = DynamicOptimizer(optimizer, model) + + for experience in benchmark.train_stream: + avalanche_model_adaptation(model, experience) + agent.optimizer.pre_adapt(agent, experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, agent.optimizer.optim) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1 + ) diff --git
a/tests/models/test_models.py b/tests/models/test_models.py index 86c5f5854..fd1dda0f0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -369,10 +369,10 @@ def test_recursive_adaptation(self): sizes = {} for t, exp in enumerate(benchmark.train_stream): sizes[t] = np.max(exp.classes_in_this_experience) + 1 - model1.recursive_adaptation(exp) + model1.pre_adapt(None, exp) # Second adaptation should not change anything for t, exp in enumerate(benchmark.train_stream): - model1.recursive_adaptation(exp) + model1.pre_adapt(None, exp) for t, s in sizes.items(): self.assertEqual(s, model1.classifiers[str(t)].classifier.out_features) @@ -385,7 +385,7 @@ def test_recursive_loop(self): model2.layer2 = model1 benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True) - model1.recursive_adaptation(benchmark.train_stream[0]) + model1.pre_adapt(None, benchmark.train_stream[0]) def test_multi_head_classifier_masking(self): benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 000000000..905f05bb6 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,60 @@ +import unittest + +from avalanche.core import Agent + + +class SimpleAdaptableObject: + def __init__(self): + self.has_called_pre = False + self.has_called_post = False + + def pre_adapt(self, agent, exp): + self.has_called_pre = True + + def post_adapt(self, agent, exp): + self.has_called_post = True + + +class AgentTests(unittest.TestCase): + def test_adapt_hooks(self): + agent = Agent() + + IS_FOO_PRE_CALLED = False + + def foo_pre(a, e): + nonlocal IS_FOO_PRE_CALLED + IS_FOO_PRE_CALLED = True + + agent.add_pre_hooks(foo_pre) + + IS_FOO_POST_CALLED = False + + def foo_post(a, e): + nonlocal IS_FOO_POST_CALLED + IS_FOO_POST_CALLED = True + + agent.add_post_hooks(foo_post) + + e = None + agent.pre_adapt(e) + assert IS_FOO_PRE_CALLED + assert not IS_FOO_POST_CALLED + + IS_FOO_PRE_CALLED = False + agent.post_adapt(e) + assert IS_FOO_POST_CALLED + assert not IS_FOO_PRE_CALLED + + def test_adapt_objects(self): + agent = Agent() + agent.obj = SimpleAdaptableObject() + + e = None + agent.pre_adapt(e) + assert agent.obj.has_called_pre + assert not agent.obj.has_called_post + + agent.obj.has_called_pre = False + agent.post_adapt(e) + assert agent.obj.has_called_post + assert not agent.obj.has_called_pre diff --git a/tests/training/test_losses.py b/tests/training/test_losses.py index 4935fb88c..5a483c785 100644 --- a/tests/training/test_losses.py +++ b/tests/training/test_losses.py @@ -41,8 +41,8 @@ def test_loss(self): cross_entropy = nn.CrossEntropyLoss() criterion = MaskedCrossEntropy(mask="new") - criterion.adaptation([1, 2, 3, 4]) - criterion.adaptation([5, 6, 7]) + criterion._adaptation([1, 2, 3, 4]) + criterion._adaptation([5, 6, 7]) mb_y = torch.tensor([5, 5, 6, 7, 6]) diff --git a/tests/training/test_regularization.py b/tests/training/test_regularization.py index f407b7b2a..92a56ace9 100644 --- a/tests/training/test_regularization.py +++ b/tests/training/test_regularization.py @@ -1,4 +1,5 @@ import unittest +from types import SimpleNamespace import torch from torch.utils.data import DataLoader @@ -33,7 +34,7 @@ def test_lwf(self): assert loss == 0 else: assert loss > 0.0 - lwf.update(exp, teacher) + lwf.post_adapt(SimpleNamespace(model=teacher), exp) lwf = LearningWithoutForgetting() teacher = MTSimpleMLP(input_size=6) @@ -60,7 +61,7 @@ def test_lwf(self): assert torch.norm(weight.grad) > 0 model.zero_grad() - lwf.update(exp, 
teacher) + lwf.post_adapt(SimpleNamespace(model=teacher), exp) if __name__ == "__main__":
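Usage sketch (not part of the diff): a minimal training loop built on the new `Agent`/`Adaptable` API from `avalanche/core.py`. `MyBuffer` and `train_stream` are hypothetical placeholders; only the protocol calls are taken from the code above.

    from avalanche.core import Agent

    class MyBuffer:
        # hypothetical adaptable object: implementing only post_adapt is enough,
        # since Agent looks the pre/post methods up dynamically
        def __init__(self):
            self.datasets = []

        def post_adapt(self, agent, exp):
            self.datasets.append(exp.dataset)

    agent = Agent(verbose=True)
    agent.buffer = MyBuffer()  # auto-registered by Agent.__setattr__

    for exp in train_stream:   # assumes an Avalanche train stream is in scope
        agent.pre_adapt(exp)   # runs pre_adapt on registered objects, then pre-hooks
        # ... training code for the current experience goes here ...
        agent.post_adapt(exp)  # runs agent.buffer.post_adapt(agent, exp), then post-hooks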
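A similar sketch for the new `seed` argument of `split_online_stream`, mirroring its use in `examples/updatable_objects.py`:

    from avalanche.benchmarks.scenarios import split_online_stream

    # same seed -> same shuffling: the two streams below yield identical splits
    s1 = split_online_stream(train_stream, experience_size=256, seed=1234)
    s2 = split_online_stream(train_stream, experience_size=256, seed=1234)
    # with seed=None a random seed is drawn once at construction time, so
    # repeated iterations over the same stream object still share one order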
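Finally, a compact sketch of the `MetricCollector` aggregation semantics exercised by `tests/evaluation/test_collector.py`; the values are arbitrary, and no `stream` argument is needed here because `exp_reduce` is left as None:

    from avalanche.evaluation.collector import MetricCollector

    mc = MetricCollector()
    mc.update({"Acc": [0.5, 0.7]})  # timestep 0: one value per experience
    mc.update({"Acc": [0.6, 0.8]})  # timestep 1
    mc.get("Acc")                      # full (time x experience) matrix
    mc.get("Acc", time_reduce="last")  # [0.6, 0.8], last timestep
    mc.get("Acc", time_reduce="mean")  # [0.55, 0.75], average over time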