From 0c2a6aef56fd0bb3b2a220d01705ad0dca025138 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 26 Jan 2024 17:08:46 +0100 Subject: [PATCH 01/21] WIP: Python 3.11 multiple inheritance compatibility --- .../ffcv_support/ffcv_transform_utils.py | 2 +- avalanche/training/supervised/ar1.py | 8 +- avalanche/training/supervised/cumulative.py | 13 +- avalanche/training/supervised/deep_slda.py | 14 +- avalanche/training/supervised/der.py | 5 + avalanche/training/supervised/er_ace.py | 29 ++-- avalanche/training/supervised/er_aml.py | 29 ++-- avalanche/training/supervised/expert_gate.py | 8 +- .../training/supervised/feature_replay.py | 24 +-- .../training/supervised/strategy_wrappers.py | 145 +++++++++--------- .../supervised_contrastive_replay.py | 26 ++-- avalanche/training/templates/base.py | 4 +- avalanche/training/templates/base_sgd.py | 33 +++- .../training/templates/common_templates.py | 115 +++++++++----- .../observation_type/batch_observation.py | 4 +- .../problem_type/supervised_problem.py | 3 + .../templates/strategy_mixin_protocol.py | 6 +- .../templates/update_type/sgd_update.py | 3 + examples/naive.py | 6 +- 19 files changed, 296 insertions(+), 181 deletions(-) diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py index f72940b88..4642ffb0f 100644 --- a/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_transform_utils.py @@ -12,8 +12,8 @@ Tuple, Type, Union, + Literal, ) -from typing_extensions import Literal import warnings import numpy as np diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 3a2f52075..8be1d75b9 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -65,6 +65,7 @@ def __init__( EvaluationPlugin, Callable[[], EvaluationPlugin] ] = default_evaluator, eval_every=-1, + **kwargs ): """ Creates an instance of the AR1 strategy. @@ -166,9 +167,9 @@ def __init__( self.replay_mb_size = 0 super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -176,6 +177,7 @@ def __init__( plugins=plugins, evaluator=evaluator, eval_every=eval_every, + **kwargs ) def _before_training_exp(self, **kwargs): diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 88d1cc626..5596ac592 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -3,9 +3,7 @@ from torch.nn import Module from torch.optim import Optimizer -from torch.utils.data import ConcatDataset -from avalanche.benchmarks.utils import _concat_taskaware_classification_datasets from avalanche.benchmarks.utils.utils import concat_datasets from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin @@ -33,6 +31,7 @@ def __init__( EvaluationPlugin, Callable[[], EvaluationPlugin] ] = default_evaluator, eval_every=-1, + **kwargs ): """Init. @@ -54,9 +53,9 @@ def __init__( """ super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -64,6 +63,7 @@ def __init__( plugins=plugins, evaluator=evaluator, eval_every=eval_every, + **kwargs ) self.dataset = None # cumulative dataset @@ -79,3 +79,6 @@ def train_dataset_adaptation(self, **kwargs): else: self.dataset = concat_datasets([self.dataset, exp.dataset]) self.adapted_dataset = self.dataset + + +__all__ = ["Cumulative"] diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 83cfec37c..00468612f 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -44,6 +44,7 @@ def __init__( EvaluationPlugin, Callable[[], EvaluationPlugin] ] = default_evaluator, eval_every=-1, + **kwargs, ): """Init function for the SLDA model. @@ -78,16 +79,17 @@ def __init__( ).eval() super(StreamingLDA, self).__init__( - slda_model, - None, # type: ignore - criterion, - train_mb_size, - train_epochs, - eval_mb_size, + model=slda_model, + optimizer=None, # type: ignore + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, device=device, plugins=plugins, evaluator=evaluator, eval_every=eval_every, + **kwargs, ) # SLDA parameters diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index 0dcc151a4..8b6e87b86 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -173,6 +173,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """ :param model: PyTorch model. @@ -219,6 +220,7 @@ def __init__( evaluator=evaluator, eval_every=eval_every, peval_mode=peval_mode, + **kwargs ) if batch_size_mem is None: self.batch_size_mem = train_mb_size @@ -328,3 +330,6 @@ def training_epoch(self, **kwargs): self._after_update(**kwargs) self._after_training_iteration(**kwargs) + + +__all__ = ["DER"] diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index 25b90d370..daf7ee480 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -47,6 +47,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """ :param model: PyTorch model. @@ -71,17 +72,18 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.mem_size = mem_size self.batch_size_mem = batch_size_mem @@ -170,3 +172,8 @@ def _train_cleanup(self): super()._train_cleanup() # reset the value to avoid serialization failures self.replay_loader = None + + +__all__ = [ + "ER_ACE", +] diff --git a/avalanche/training/supervised/er_aml.py b/avalanche/training/supervised/er_aml.py index 0b89aefd4..d247f8b76 100644 --- a/avalanche/training/supervised/er_aml.py +++ b/avalanche/training/supervised/er_aml.py @@ -46,6 +46,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """ :param model: PyTorch model. @@ -74,17 +75,18 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.mem_size = mem_size self.batch_size_mem = batch_size_mem @@ -198,3 +200,8 @@ def _train_cleanup(self): super()._train_cleanup() # reset the value to avoid serialization failures self.replay_loader = None + + +__all__ = [ + "ER_AML", +] diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index 223e488e9..67d84e3b5 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -12,15 +12,12 @@ from collections import OrderedDict import warnings -import torch from torch.nn import Module, CrossEntropyLoss -from torch.optim import Optimizer, SGD, Adam -from torch.nn.functional import log_softmax +from torch.optim import Optimizer, SGD from typing import Optional, List from avalanche.models.expert_gate import ExpertAutoencoder, ExpertModel, ExpertGate -from avalanche.models.dynamic_optimizers import reset_optimizer from avalanche.training.supervised import AETraining from avalanche.training.templates import SupervisedTemplate from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin @@ -367,3 +364,6 @@ def _train_autoencoder(self, strategy: "SupervisedTemplate", autoencoder): # Train with autoencoder strategy ae_strategy.train(strategy.experience) print("FINISHED TRAINING NEW AUTOENCODER\n") + + +__all__ = ["ExpertGateStrategy"] diff --git a/avalanche/training/supervised/feature_replay.py b/avalanche/training/supervised/feature_replay.py index a170b3340..40e29f2b5 100644 --- a/avalanche/training/supervised/feature_replay.py +++ b/avalanche/training/supervised/feature_replay.py @@ -62,19 +62,21 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): super().__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.mem_size = mem_size self.batch_size_mem = batch_size_mem diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index 30a4484a0..e9c5b945b 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -65,8 +65,9 @@ class Naive(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, + *args, + model: Module = None, + optimizer: Optimizer = None, criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -100,10 +101,12 @@ def __init__( :param base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ + super().__init__( - model, - optimizer, - criterion, + *args, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -123,8 +126,9 @@ class PNNStrategy(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, + *args, + model: Module = None, + optimizer: Optimizer = None, criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -158,8 +162,11 @@ def __init__( :class:`~avalanche.training.BaseTemplate` constructor arguments. """ # Check that the model has the correct architecture. - assert isinstance(model, PNN), "PNNStrategy requires a PNN model." + assert isinstance(model, PNN) or ( + len(args) > 0 and isinstance(args[0], PNN) + ), "PNNStrategy requires a PNN model." super().__init__( + *args, model=model, optimizer=optimizer, criterion=criterion, @@ -244,9 +251,9 @@ def __init__( raise ValueError("PackNet requires a model that implements PackNetModule.") super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -306,9 +313,9 @@ def __init__( else: plugins.append(cwsp) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -372,9 +379,9 @@ def __init__( else: plugins.append(rp) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -503,9 +510,9 @@ def __init__( plugins.append(rp) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -571,9 +578,9 @@ def __init__( """ super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -645,9 +652,9 @@ def __init__( """ super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -720,9 +727,9 @@ def __init__( else: plugins.append(rp) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -787,9 +794,9 @@ def __init__( plugins.append(gdumb) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -856,9 +863,9 @@ def __init__( plugins.append(lwf) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -926,9 +933,9 @@ def __init__( plugins.append(agem) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -996,9 +1003,9 @@ def __init__( plugins.append(gem) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1077,9 +1084,9 @@ def __init__( plugins.append(ewc) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1161,13 +1168,13 @@ def __init__( # entire implementation of the strategy! plugins.append(SynapticIntelligencePlugin(si_lambda=si_lambda, eps=eps)) - super(SynapticIntelligence, self).__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, device=device, plugins=plugins, evaluator=evaluator, @@ -1240,9 +1247,9 @@ def __init__( else: plugins.append(copep) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1308,9 +1315,9 @@ def __init__( plugins.append(lfl) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1385,9 +1392,9 @@ def __init__( plugins.append(mas) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1477,9 +1484,9 @@ def __init__( plugins.append(bic) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1552,9 +1559,9 @@ def __init__( plugins.append(mir) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -1621,9 +1628,9 @@ def __init__( else: plugins.append(fstp) super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, diff --git a/avalanche/training/supervised/supervised_contrastive_replay.py b/avalanche/training/supervised/supervised_contrastive_replay.py index 6f32cb33b..8eba2c6bc 100644 --- a/avalanche/training/supervised/supervised_contrastive_replay.py +++ b/avalanche/training/supervised/supervised_contrastive_replay.py @@ -48,6 +48,7 @@ def __init__( evaluator=default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """ :param model: an Avalanche model like the avalanche.models.SCRModel, @@ -111,17 +112,17 @@ def __init__( else: raise ValueError("`plugins` parameter needs to be a list.") super().__init__( - model, - optimizer, - SCRLoss(temperature=self.temperature), - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + model=model, + optimizer=optimizer, + criterion=SCRLoss(temperature=self.temperature), + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, ) def criterion(self): @@ -189,3 +190,6 @@ def compute_class_means(self): class_means[label] /= class_means[label].norm() self.model.eval_classifier.update_class_means_dict(class_means) + + +__all__ = ["SCR"] diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 7cd34c6a6..4cd5abe22 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -39,12 +39,14 @@ class BaseTemplate(BaseStrategyProtocol[TExperienceType]): def __init__( self, + *, model: Module, device: Union[str, torch.device] = "cpu", plugins: Optional[Sequence[BasePlugin]] = None, + **kwargs, ): - super().__init__() """Init.""" + super().__init__(model=model, device=device, plugins=plugins, **kwargs) self.model: Module = model """ PyTorch model. """ diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 4bcaec94c..8c8686f9e 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union +from typing_extensions import TypeAlias from packaging.version import parse import torch @@ -28,10 +29,12 @@ TMBInput = TypeVar("TMBInput") TMBOutput = TypeVar("TMBOutput") +CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]] + class BaseSGDTemplate( - SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], BaseTemplate[TDatasetExperience], + SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], ): """Base SGD class for continual learning skeletons. @@ -53,9 +56,10 @@ class BaseSGDTemplate( def __init__( self, + *, model: Module, optimizer: Optimizer, - criterion=CrossEntropyLoss(), + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, @@ -66,6 +70,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """Init. @@ -87,8 +92,21 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ - super().__init__() # type: ignore - BaseTemplate.__init__(self=self, model=model, device=device, plugins=plugins) + # Call super with all args + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs + ) self.optimizer: Optimizer = optimizer """ PyTorch optimizer. """ @@ -617,3 +635,10 @@ def after_training_exp(self, strategy, **kwargs): # if self.peval_mode == "experience": # self._maybe_peval(strategy, strategy.clock.train_exp_counter, # **kwargs) + + +__all__ = [ + "CriterionType", + "BaseSGDTemplate", + "PeriodicEval", +] diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 61ba13fac..280068941 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,8 +1,10 @@ from typing import Callable, Sequence, Optional, TypeVar, Union +import warnings import torch from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer +from torch import Tensor from ...benchmarks.scenarios.deprecated import DatasetExperience from avalanche.core import BasePlugin, BaseSGDPlugin, SupervisedPlugin @@ -17,7 +19,7 @@ from .observation_type import * from .problem_type import * from .update_type import * -from .base_sgd import BaseSGDTemplate +from .base_sgd import BaseSGDTemplate, CriterionType TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience) @@ -25,12 +27,45 @@ TMBOutput = TypeVar("TMBOutput") +def _merge_legacy_positional_arguments(args, kwargs): + # Manage the legacy positional constructor parameters + if len(args) > 0: + warnings.warn( + "Passing positional arguments to Strategy constructors is " + "deprecated. Please use keyword arguments instead." + ) + + # unroll args and apply it to kwargs + legacy_kwargs_order = [ + "model", + "optimizer", + "criterion", + "train_mb_size", + "train_epochs", + "eval_mb_size", + "device", + "plugins", + "evaluator", + "eval_every", + "peval_mode", + ] + + for i, arg in enumerate(args): + kwargs[legacy_kwargs_order[i]] = arg + + for key, value in kwargs.items(): + if value == "not_set": + raise ValueError(f"Parameter {key} is not set") + + return kwargs + + class SupervisedTemplate( BatchObservation, SupervisedProblem, SGDUpdate, - SupervisedStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], BaseSGDTemplate[TDatasetExperience, TMBInput, TMBOutput], + SupervisedStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], ): """Base class for continual learning strategies. @@ -79,11 +114,13 @@ class SupervisedTemplate( PLUGIN_CLASS = SupervisedPlugin + # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, @@ -94,6 +131,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs, ): """Init. @@ -119,21 +157,21 @@ def __init__( periodic evaluation during training should execute every `eval_every` epochs or iterations (Default='epoch'). """ - super().__init__() # type: ignore - BaseSGDTemplate.__init__( - self=self, - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=train_mb_size, - train_epochs=train_epochs, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - ) + kwargs["model"] = model + kwargs["optimizer"] = optimizer + kwargs["criterion"] = criterion + kwargs["train_mb_size"] = train_mb_size + kwargs["train_epochs"] = train_epochs + kwargs["eval_mb_size"] = eval_mb_size + kwargs["device"] = device + kwargs["plugins"] = plugins + kwargs["evaluator"] = evaluator + kwargs["eval_every"] = eval_every + kwargs["peval_mode"] = peval_mode + + kwargs = _merge_legacy_positional_arguments(args, kwargs) + + super().__init__(**kwargs) ################################################################### # State variables. These are updated during the train/eval loops. # ################################################################### @@ -201,10 +239,12 @@ class SupervisedMetaLearningTemplate( PLUGIN_CLASS = SupervisedPlugin + # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, - model: Module, - optimizer: Optimizer, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -216,6 +256,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs, ): """Init. @@ -241,21 +282,21 @@ def __init__( periodic evaluation during training should execute every `eval_every` epochs or iterations (Default='epoch'). """ - super().__init__() # type: ignore - BaseSGDTemplate.__init__( - self=self, - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=train_mb_size, - train_epochs=train_epochs, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - ) + kwargs["model"] = model + kwargs["optimizer"] = optimizer + kwargs["criterion"] = criterion + kwargs["train_mb_size"] = train_mb_size + kwargs["train_epochs"] = train_epochs + kwargs["eval_mb_size"] = eval_mb_size + kwargs["device"] = device + kwargs["plugins"] = plugins + kwargs["evaluator"] = evaluator + kwargs["eval_every"] = eval_every + kwargs["peval_mode"] = peval_mode + + kwargs = _merge_legacy_positional_arguments(args, kwargs) + + super().__init__(**kwargs) ################################################################### # State variables. These are updated during the train/eval loops. # ################################################################### diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index e71bcf61e..6bf2cb696 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -12,8 +12,8 @@ class BatchObservation(SGDStrategyProtocol): - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.optimized_param_id = None def model_adaptation(self, model=None): diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index b2a264b63..2479dfe05 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -8,6 +8,9 @@ # Also confirmed here: https://stackoverflow.com/a/70907644 # PyLance just does not understand it class SupervisedProblem(SupervisedStrategyProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + @property def mb_x(self): """Current mini-batch input.""" diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index 217e9a1f2..ea3c54f8e 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -104,7 +104,8 @@ def _after_training_iteration(self, **kwargs): class SupervisedStrategyProtocol( - SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], Protocol + SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], + Protocol[TSGDExperienceType, TMBinput, TMBoutput], ): mb_x: Tensor @@ -114,7 +115,8 @@ class SupervisedStrategyProtocol( class MetaLearningStrategyProtocol( - SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], Protocol + SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], + Protocol[TSGDExperienceType, TMBinput, TMBoutput], ): def _before_inner_updates(self, **kwargs): ... diff --git a/avalanche/training/templates/update_type/sgd_update.py b/avalanche/training/templates/update_type/sgd_update.py index 8014dc409..f44343367 100644 --- a/avalanche/training/templates/update_type/sgd_update.py +++ b/avalanche/training/templates/update_type/sgd_update.py @@ -2,6 +2,9 @@ class SGDUpdate(SGDStrategyProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def training_epoch(self, **kwargs): """Training epoch. diff --git a/examples/naive.py b/examples/naive.py index f03c1fc6f..e63241c90 100644 --- a/examples/naive.py +++ b/examples/naive.py @@ -39,9 +39,9 @@ def main(): # create strategy strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=1, device=device, train_mb_size=32, From 2137b3c6ade44be8c3443ff1c8fbe296d9f3f80e Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 29 Jan 2024 15:28:10 +0100 Subject: [PATCH 02/21] Resolve multiple inheritance compatibility --- avalanche/training/templates/base_sgd.py | 26 +++++++++++- .../training/templates/common_templates.py | 42 ++++++------------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 8c8686f9e..600af3361 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union from typing_extensions import TypeAlias from packaging.version import parse @@ -33,8 +34,8 @@ class BaseSGDTemplate( - BaseTemplate[TDatasetExperience], SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], + BaseTemplate[TDatasetExperience], ): """Base SGD class for continual learning skeletons. @@ -92,7 +93,6 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ - # Call super with all args super().__init__( model=model, optimizer=optimizer, @@ -108,6 +108,28 @@ def __init__( **kwargs ) + # Call super with all args + if sys.version_info >= (3, 11): + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs + ) + else: + super().__init__() # type: ignore + BaseTemplate.__init__( + self=self, model=model, device=device, plugins=plugins + ) + self.optimizer: Optimizer = optimizer """ PyTorch optimizer. """ diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 280068941..9f48e1d7b 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,3 +1,4 @@ +import sys from typing import Callable, Sequence, Optional, TypeVar, Union import warnings import torch @@ -64,10 +65,9 @@ class SupervisedTemplate( BatchObservation, SupervisedProblem, SGDUpdate, - BaseSGDTemplate[TDatasetExperience, TMBInput, TMBOutput], SupervisedStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], + BaseSGDTemplate[TDatasetExperience, TMBInput, TMBOutput], ): - """Base class for continual learning strategies. SupervisedTemplate is the super class of all supervised task-based @@ -171,20 +171,11 @@ def __init__( kwargs = _merge_legacy_positional_arguments(args, kwargs) - super().__init__(**kwargs) - ################################################################### - # State variables. These are updated during the train/eval loops. # - ################################################################### - - # self.adapted_dataset = None - # """ Data used to train. It may be modified by plugins. Plugins can - # append data to it (e.g. for replay). - # - # .. note:: - # - # This dataset may contain samples from different experiences. If you - # want the original data for the current experience - # use :attr:`.BaseTemplate.experience`. + if sys.version_info >= (3, 11): + super().__init__(**kwargs) + else: + super().__init__() + BaseSGDTemplate.__init__(self=self, **kwargs) class SupervisedMetaLearningTemplate( @@ -296,20 +287,11 @@ def __init__( kwargs = _merge_legacy_positional_arguments(args, kwargs) - super().__init__(**kwargs) - ################################################################### - # State variables. These are updated during the train/eval loops. # - ################################################################### - - # self.adapted_dataset = None - # """ Data used to train. It may be modified by plugins. Plugins can - # append data to it (e.g. for replay). - # - # .. note:: - # - # This dataset may contain samples from different experiences. If you - # want the original data for the current experience - # use :attr:`.BaseTemplate.experience`. + if sys.version_info >= (3, 11): + super().__init__(**kwargs) + else: + super().__init__() + BaseSGDTemplate.__init__(self=self, **kwargs) __all__ = [ From 84a773e88134157a95ceeaa32293301036c6c873 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 29 Jan 2024 15:28:21 +0100 Subject: [PATCH 03/21] black linting --- avalanche/_annotations.py | 1 + .../deprecated/classification_scenario.py | 6 +- .../scenarios/deprecated/dataset_scenario.py | 18 ++--- .../deprecated/generic_benchmark_creation.py | 6 +- .../deprecated/lazy_dataset_sequence.py | 18 +++-- .../deprecated/new_classes/nc_scenario.py | 7 +- .../scenarios/detection_scenario.py | 6 +- .../benchmarks/scenarios/generic_scenario.py | 6 +- .../utils/classification_dataset.py | 35 ++++------ avalanche/benchmarks/utils/data.py | 6 +- avalanche/benchmarks/utils/data_attribute.py | 6 +- .../benchmarks/utils/dataset_definitions.py | 6 +- avalanche/benchmarks/utils/dataset_utils.py | 6 +- .../benchmarks/utils/detection_dataset.py | 27 +++----- .../utils/ffcv_support/ffcv_epoch_iterator.py | 1 + .../utils/ffcv_support/ffcv_loader.py | 1 + avalanche/benchmarks/utils/flat_data.py | 18 ++--- avalanche/evaluation/metric_definitions.py | 6 +- .../evaluation/metrics/forgetting_bwt.py | 9 +-- .../evaluation/metrics/labels_repartition.py | 20 +++--- avalanche/evaluation/metrics/mean_scores.py | 1 + avalanche/models/slim_resnet18.py | 1 + avalanche/models/timm_vit.py | 9 ++- avalanche/training/__init__.py | 1 + avalanche/training/plugins/agem.py | 16 +++-- avalanche/training/plugins/bic.py | 6 +- avalanche/training/plugins/gem.py | 16 +++-- avalanche/training/regularization.py | 1 + avalanche/training/supervised/__init__.py | 1 + avalanche/training/supervised/expert_gate.py | 6 +- .../training/supervised/joint_training.py | 6 +- avalanche/training/supervised/lamaml.py | 8 ++- avalanche/training/supervised/lamaml_v2.py | 8 ++- avalanche/training/templates/__init__.py | 1 + avalanche/training/templates/base.py | 12 ++-- .../templates/observation_type/__init__.py | 1 + .../observation_type/batch_observation.py | 6 -- .../templates/problem_type/__init__.py | 1 + .../templates/strategy_mixin_protocol.py | 66 +++++++------------ examples/detection.py | 8 +-- examples/detection_lvis.py | 8 +-- examples/hf_datawrapper.py | 1 + examples/nlp_qa.py | 3 +- examples/tvdetection/train.py | 1 + profiling/data_merging.py | 1 + profiling/serialization.py | 1 + tests/test_loggers.py | 1 + 47 files changed, 176 insertions(+), 224 deletions(-) diff --git a/avalanche/_annotations.py b/avalanche/_annotations.py index 7a6b12df6..1820ae811 100644 --- a/avalanche/_annotations.py +++ b/avalanche/_annotations.py @@ -4,6 +4,7 @@ only for internal use in the library. """ + import inspect import warnings import functools diff --git a/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py b/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py index ecdabacaa..b195a27f6 100644 --- a/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py @@ -257,12 +257,10 @@ def __len__(self) -> int: return len(self._benchmark.streams[self._stream]) @overload - def __getitem__(self, exp_id: int) -> Optional[Set[int]]: - ... + def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ... @overload - def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: - ... + def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ... def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet: indexing_collate = _LazyClassesInClassificationExps._slice_collate diff --git a/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py b/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py index 5547a23b5..5af73eb9b 100644 --- a/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py @@ -184,17 +184,17 @@ def __init__( invoking the super constructor) to specialize the experience class. """ - self.experience_factory: Callable[ - [TCLStream, int], TDatasetExperience - ] = experience_factory + self.experience_factory: Callable[[TCLStream, int], TDatasetExperience] = ( + experience_factory + ) - self.stream_factory: Callable[ - [str, TDatasetScenario], TCLStream - ] = stream_factory + self.stream_factory: Callable[[str, TDatasetScenario], TCLStream] = ( + stream_factory + ) - self.stream_definitions: Dict[ - str, StreamDef[TCLDataset] - ] = DatasetScenario._check_stream_definitions(stream_definitions) + self.stream_definitions: Dict[str, StreamDef[TCLDataset]] = ( + DatasetScenario._check_stream_definitions(stream_definitions) + ) """ A structure containing the definition of the streams. """ diff --git a/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py b/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py index 9bf1e9b6e..351d7b6f3 100644 --- a/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py +++ b/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py @@ -143,9 +143,9 @@ def create_multi_dataset_generic_benchmark( "complete_test_set_only is True" ) - stream_definitions: Dict[ - str, Tuple[Iterable[TaskAwareClassificationDataset]] - ] = dict() + stream_definitions: Dict[str, Tuple[Iterable[TaskAwareClassificationDataset]]] = ( + dict() + ) for stream_name, dataset_list in input_streams.items(): initial_transform_group = "train" diff --git a/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py b/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py index 1d20fb774..82251d0da 100644 --- a/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py +++ b/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py @@ -99,9 +99,9 @@ def __init__( now, including the ones of dropped experiences. """ - self.task_labels_field_sequence: Dict[ - int, Optional[Sequence[int]] - ] = defaultdict(lambda: None) + self.task_labels_field_sequence: Dict[int, Optional[Sequence[int]]] = ( + defaultdict(lambda: None) + ) """ A dictionary mapping each experience to its `targets_task_labels` field. @@ -118,12 +118,10 @@ def __len__(self) -> int: return self._stream_length @overload - def __getitem__(self, exp_idx: int) -> TCLDataset: - ... + def __getitem__(self, exp_idx: int) -> TCLDataset: ... @overload - def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: - ... + def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: ... def __getitem__( self, exp_idx: Union[int, slice] @@ -135,9 +133,9 @@ def __getitem__( :return: The dataset associated to the experience. """ # A lot of unuseful lines needed for MyPy -_- - indexing_collate: Callable[ - [Iterable[TCLDataset]], Sequence[TCLDataset] - ] = lambda x: list(x) + indexing_collate: Callable[[Iterable[TCLDataset]], Sequence[TCLDataset]] = ( + lambda x: list(x) + ) result = manage_advanced_indexing( exp_idx, self._get_experience_and_load_if_needed, diff --git a/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py b/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py index 8c400ef9d..a5509b18d 100644 --- a/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py @@ -34,7 +34,6 @@ class NCScenario( "NCStream", "NCExperience", TaskAwareSupervisedClassificationDataset ] ): - """ This class defines a "New Classes" scenario. Once created, an instance of this class can be iterated in order to obtain the experience sequence @@ -328,9 +327,9 @@ class "34" will be mapped to "1", class "11" to "2" and so on. # used, the user may have defined an amount of classes less than # the overall amount of classes in the dataset. if class_id in self.classes_order_original_ids: - self.class_mapping[ - class_id - ] = self.classes_order_original_ids.index(class_id) + self.class_mapping[class_id] = ( + self.classes_order_original_ids.index(class_id) + ) elif self.class_ids_from_zero_in_each_exp: # Method 2: remap class IDs so that they appear in range [0, N] in # each experience diff --git a/avalanche/benchmarks/scenarios/detection_scenario.py b/avalanche/benchmarks/scenarios/detection_scenario.py index 8b16be027..111c5ac11 100644 --- a/avalanche/benchmarks/scenarios/detection_scenario.py +++ b/avalanche/benchmarks/scenarios/detection_scenario.py @@ -254,12 +254,10 @@ def __len__(self): return len(self._benchmark.streams[self._stream]) @overload - def __getitem__(self, exp_id: int) -> Optional[Set[int]]: - ... + def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ... @overload - def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: - ... + def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ... def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet: indexing_collate = _LazyClassesInDetectionExps._slice_collate diff --git a/avalanche/benchmarks/scenarios/generic_scenario.py b/avalanche/benchmarks/scenarios/generic_scenario.py index d7b5ba09c..34da0d249 100644 --- a/avalanche/benchmarks/scenarios/generic_scenario.py +++ b/avalanche/benchmarks/scenarios/generic_scenario.py @@ -427,12 +427,10 @@ def __iter__(self) -> Iterator[TCLExperience]: yield exp @overload - def __getitem__(self, item: int) -> TCLExperience: - ... + def __getitem__(self, item: int) -> TCLExperience: ... @overload - def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: - ... + def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: ... @final def __getitem__( diff --git a/avalanche/benchmarks/utils/classification_dataset.py b/avalanche/benchmarks/utils/classification_dataset.py index acb4f880b..ad6db47f6 100644 --- a/avalanche/benchmarks/utils/classification_dataset.py +++ b/avalanche/benchmarks/utils/classification_dataset.py @@ -175,8 +175,7 @@ def _make_taskaware_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -190,8 +189,7 @@ def _make_taskaware_classification_dataset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -205,8 +203,7 @@ def _make_taskaware_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _make_taskaware_classification_dataset( @@ -386,8 +383,7 @@ def _taskaware_classification_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -403,8 +399,7 @@ def _taskaware_classification_subset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -420,8 +415,7 @@ def _taskaware_classification_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _taskaware_classification_subset( @@ -619,8 +613,7 @@ def _make_taskaware_tensor_classification_dataset( task_labels: Union[int, Sequence[int]], targets: Union[Sequence[TTargetType], int], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -633,8 +626,9 @@ def _make_taskaware_tensor_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Union[Sequence[TTargetType], int]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]: - ... +) -> Union[ + TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset +]: ... def _make_taskaware_tensor_classification_dataset( @@ -759,8 +753,7 @@ def _concat_taskaware_classification_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -774,8 +767,7 @@ def _concat_taskaware_classification_datasets( task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]], targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -791,8 +783,7 @@ def _concat_taskaware_classification_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _concat_taskaware_classification_datasets( diff --git a/avalanche/benchmarks/utils/data.py b/avalanche/benchmarks/utils/data.py index 76908345e..9985e19fa 100644 --- a/avalanche/benchmarks/utils/data.py +++ b/avalanche/benchmarks/utils/data.py @@ -344,12 +344,10 @@ def __eq__(self, other: object): ) @overload - def __getitem__(self, exp_id: int) -> T_co: - ... + def __getitem__(self, exp_id: int) -> T_co: ... @overload - def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: - ... + def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: ... def __getitem__( self: TAvalancheDataset, idx: Union[int, slice] diff --git a/avalanche/benchmarks/utils/data_attribute.py b/avalanche/benchmarks/utils/data_attribute.py index 3cb88e58c..d3ee4b0d9 100644 --- a/avalanche/benchmarks/utils/data_attribute.py +++ b/avalanche/benchmarks/utils/data_attribute.py @@ -62,12 +62,10 @@ def __iter__(self): yield self[i] @overload - def __getitem__(self, item: int) -> T_co: - ... + def __getitem__(self, item: int) -> T_co: ... @overload - def __getitem__(self, item: slice) -> Sequence[T_co]: - ... + def __getitem__(self, item: slice) -> Sequence[T_co]: ... def __getitem__(self, item: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: return self.data[item] diff --git a/avalanche/benchmarks/utils/dataset_definitions.py b/avalanche/benchmarks/utils/dataset_definitions.py index ca8de5b54..de9fa05da 100644 --- a/avalanche/benchmarks/utils/dataset_definitions.py +++ b/avalanche/benchmarks/utils/dataset_definitions.py @@ -38,11 +38,9 @@ class IDataset(Protocol[T_co]): Note: no __add__ method is defined. """ - def __getitem__(self, index: int) -> T_co: - ... + def __getitem__(self, index: int) -> T_co: ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... class IDatasetWithTargets(IDataset[T_co], Protocol[T_co, TTargetType_co]): diff --git a/avalanche/benchmarks/utils/dataset_utils.py b/avalanche/benchmarks/utils/dataset_utils.py index 717d67e75..867ba9cc6 100644 --- a/avalanche/benchmarks/utils/dataset_utils.py +++ b/avalanche/benchmarks/utils/dataset_utils.py @@ -64,12 +64,10 @@ def __iter__(self) -> Iterator[TData]: yield el @overload - def __getitem__(self, item: int) -> TData: - ... + def __getitem__(self, item: int) -> TData: ... @overload - def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: - ... + def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: ... @final def __getitem__( diff --git a/avalanche/benchmarks/utils/detection_dataset.py b/avalanche/benchmarks/utils/detection_dataset.py index 7f3b3b632..6c5efb43f 100644 --- a/avalanche/benchmarks/utils/detection_dataset.py +++ b/avalanche/benchmarks/utils/detection_dataset.py @@ -151,8 +151,7 @@ def make_detection_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -166,8 +165,7 @@ def make_detection_dataset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -181,8 +179,7 @@ def make_detection_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def make_detection_dataset( @@ -373,8 +370,7 @@ def detection_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -390,8 +386,7 @@ def detection_subset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -407,8 +402,7 @@ def detection_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def detection_subset( @@ -595,8 +589,7 @@ def concat_detection_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -610,8 +603,7 @@ def concat_detection_datasets( task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]], targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -627,8 +619,7 @@ def concat_detection_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def concat_detection_datasets( diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py index 7c36b633f..44fbfcedc 100644 --- a/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_epoch_iterator.py @@ -1,6 +1,7 @@ """ Custom version of the FFCV epoch iterator. """ + from threading import Thread, Event, Lock from queue import Queue from typing import List, Sequence, TYPE_CHECKING diff --git a/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py b/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py index 8f2ae9f10..1deacbb97 100644 --- a/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py +++ b/avalanche/benchmarks/utils/ffcv_support/ffcv_loader.py @@ -1,6 +1,7 @@ """ Custom version of the FFCV loader that accepts a batch sampler. """ + from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union import warnings diff --git a/avalanche/benchmarks/utils/flat_data.py b/avalanche/benchmarks/utils/flat_data.py index d9fd73ee0..02d3681cd 100644 --- a/avalanche/benchmarks/utils/flat_data.py +++ b/avalanche/benchmarks/utils/flat_data.py @@ -66,9 +66,9 @@ def __init__( else: new_lists.append(ll) - self._lists: Optional[ - List[Sequence[int]] - ] = new_lists # freed after eagerification + self._lists: Optional[List[Sequence[int]]] = ( + new_lists # freed after eagerification + ) self._offset: int = int(offset) self._eager_list: Optional[np.ndarray] = None """This is the list where we save indices @@ -408,12 +408,10 @@ def _get_idx(self, idx) -> Tuple[int, int]: return dataset_idx, int(idx) @overload - def __getitem__(self, item: int) -> T_co: - ... + def __getitem__(self, item: int) -> T_co: ... @overload - def __getitem__(self: TFlatData, item: slice) -> TFlatData: - ... + def __getitem__(self: TFlatData, item: slice) -> TFlatData: ... def __getitem__(self: TFlatData, item: Union[int, slice]) -> Union[T_co, TFlatData]: if isinstance(item, (int, np.integer)): @@ -470,12 +468,10 @@ def __len__(self): return self._size @overload - def __getitem__(self, index: int) -> DataT: - ... + def __getitem__(self, index: int) -> DataT: ... @overload - def __getitem__(self, index: slice) -> "ConstantSequence[DataT]": - ... + def __getitem__(self, index: slice) -> "ConstantSequence[DataT]": ... def __getitem__( self, index: Union[int, slice] diff --git a/avalanche/evaluation/metric_definitions.py b/avalanche/evaluation/metric_definitions.py index 1934d5392..c91a7122c 100644 --- a/avalanche/evaluation/metric_definitions.py +++ b/avalanche/evaluation/metric_definitions.py @@ -215,8 +215,7 @@ def __init__( ] = "experience", emit_at: Literal["iteration", "epoch", "experience", "stream"] = "experience", mode: Literal["train"] = "train", - ): - ... + ): ... @overload def __init__( @@ -225,8 +224,7 @@ def __init__( reset_at: Literal["iteration", "experience", "stream", "never"] = "experience", emit_at: Literal["iteration", "experience", "stream"] = "experience", mode: Literal["eval"] = "eval", - ): - ... + ): ... def __init__( self, metric: TMetric, reset_at="experience", emit_at="experience", mode="eval" diff --git a/avalanche/evaluation/metrics/forgetting_bwt.py b/avalanche/evaluation/metrics/forgetting_bwt.py index 9e5f79094..29d9ffd28 100644 --- a/avalanche/evaluation/metrics/forgetting_bwt.py +++ b/avalanche/evaluation/metrics/forgetting_bwt.py @@ -526,18 +526,15 @@ def forgetting_metrics(*, experience=False, stream=False) -> List[PluginMetric]: @overload -def forgetting_to_bwt(f: float) -> float: - ... +def forgetting_to_bwt(f: float) -> float: ... @overload -def forgetting_to_bwt(f: Dict[int, float]) -> Dict[int, float]: - ... +def forgetting_to_bwt(f: Dict[int, float]) -> Dict[int, float]: ... @overload -def forgetting_to_bwt(f: None) -> None: - ... +def forgetting_to_bwt(f: None) -> None: ... def forgetting_to_bwt(f: Optional[Union[float, Dict[int, float]]]): diff --git a/avalanche/evaluation/metrics/labels_repartition.py b/avalanche/evaluation/metrics/labels_repartition.py index 0d55ec5f3..7a5304696 100644 --- a/avalanche/evaluation/metrics/labels_repartition.py +++ b/avalanche/evaluation/metrics/labels_repartition.py @@ -85,8 +85,7 @@ def __init__( ] = default_history_repartition_image_creator, mode: Literal["train"] = "train", emit_reset_at: Literal["stream", "experience", "epoch"] = "epoch", - ): - ... + ): ... @overload def __init__( @@ -97,8 +96,7 @@ def __init__( ] = default_history_repartition_image_creator, mode: Literal["eval"] = "eval", emit_reset_at: Literal["stream", "experience"], - ): - ... + ): ... def __init__( self, @@ -171,12 +169,14 @@ def _package_result(self, strategy: "SupervisedTemplate") -> "MetricResult": f"/{self._mode}_phase" f"/{stream_type(strategy.experience)}_stream" f"/Task_{task:03}", - value=AlternativeValues( - self.image_creator(label2counts, self.steps), - label2counts, - ) - if self.image_creator is not None - else label2counts, + value=( + AlternativeValues( + self.image_creator(label2counts, self.steps), + label2counts, + ) + if self.image_creator is not None + else label2counts + ), x_plot=strategy.clock.train_iterations, ) for task, label2counts in self.task2label2counts.items() diff --git a/avalanche/evaluation/metrics/mean_scores.py b/avalanche/evaluation/metrics/mean_scores.py index a7e69276e..b0f84170f 100644 --- a/avalanche/evaluation/metrics/mean_scores.py +++ b/avalanche/evaluation/metrics/mean_scores.py @@ -9,6 +9,7 @@ It selects the scores of the true class and then average them for past and new classes. """ + from abc import ABC, abstractmethod from collections import defaultdict from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional, TypeVar, Literal diff --git a/avalanche/models/slim_resnet18.py b/avalanche/models/slim_resnet18.py index f5cd0fed6..a27135d44 100644 --- a/avalanche/models/slim_resnet18.py +++ b/avalanche/models/slim_resnet18.py @@ -1,4 +1,5 @@ """This is the slimmed ResNet as used by Lopez et al. in the GEM paper.""" + import torch.nn as nn from torch.nn.functional import relu, avg_pool2d from avalanche.models import MultiHeadClassifier, MultiTaskModule, DynamicModule diff --git a/avalanche/models/timm_vit.py b/avalanche/models/timm_vit.py index c2e0bf45a..21dcc677a 100644 --- a/avalanche/models/timm_vit.py +++ b/avalanche/models/timm_vit.py @@ -23,6 +23,7 @@ # -- Jaeho Lee, dlwogh9344@khu.ac.kr # ------------------------------------------ """ + import math import re @@ -442,9 +443,11 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): v = resize_pos_embed( v, model.pos_embed, - 0 - if getattr(model, "no_embed_class") - else getattr(model, "num_prefix_tokens", 1), + ( + 0 + if getattr(model, "no_embed_class") + else getattr(model, "num_prefix_tokens", 1) + ), model.patch_embed.grid_size, ) elif adapt_layer_scale and "gamma_" in k: diff --git a/avalanche/training/__init__.py b/avalanche/training/__init__.py index 0a995200a..7780e066f 100644 --- a/avalanche/training/__init__.py +++ b/avalanche/training/__init__.py @@ -5,6 +5,7 @@ class (:py:class:`BaseTemplate`) and implementations of the most common :py:mod:`training.supervised` or as plugins (:py:mod:`training.plugins`) that can be easily combined with your own strategy. """ + from .regularization import * from .supervised import * from .storage_policy import * diff --git a/avalanche/training/plugins/agem.py b/avalanche/training/plugins/agem.py index db1cfd7a5..9cdb219d9 100644 --- a/avalanche/training/plugins/agem.py +++ b/avalanche/training/plugins/agem.py @@ -61,9 +61,11 @@ def before_training_iteration(self, strategy, **kwargs): loss.backward() # gradient can be None for some head on multi-headed models reference_gradients_list = [ - p.grad.view(-1) - if p.grad is not None - else torch.zeros(p.numel(), device=strategy.device) + ( + p.grad.view(-1) + if p.grad is not None + else torch.zeros(p.numel(), device=strategy.device) + ) for n, p in strategy.model.named_parameters() ] self.reference_gradients = torch.cat(reference_gradients_list) @@ -76,9 +78,11 @@ def after_backward(self, strategy, **kwargs): """ if len(self.buffers) > 0: current_gradients_list = [ - p.grad.view(-1) - if p.grad is not None - else torch.zeros(p.numel(), device=strategy.device) + ( + p.grad.view(-1) + if p.grad is not None + else torch.zeros(p.numel(), device=strategy.device) + ) for n, p in strategy.model.named_parameters() ] current_gradients = torch.cat(current_gradients_list) diff --git a/avalanche/training/plugins/bic.py b/avalanche/training/plugins/bic.py index 61535fd37..caef605b9 100644 --- a/avalanche/training/plugins/bic.py +++ b/avalanche/training/plugins/bic.py @@ -454,9 +454,9 @@ def _classes_groups(self, strategy: SupervisedTemplate): # - "current" classes: seen in current_experience # "initial" classes - initial_classes: Set[ - int - ] = set() # pre_initial_cl in the original implementation + initial_classes: Set[int] = ( + set() + ) # pre_initial_cl in the original implementation previous_classes: Set[int] = set() # pre_new_cl in the original implementation current_classes: Set[int] = set() # new_cl in the original implementation # Note: pre_initial_cl + pre_new_cl is "initial_cl" in the original implementation diff --git a/avalanche/training/plugins/gem.py b/avalanche/training/plugins/gem.py index f0b01128f..26eccdad6 100644 --- a/avalanche/training/plugins/gem.py +++ b/avalanche/training/plugins/gem.py @@ -59,9 +59,11 @@ def before_training_iteration(self, strategy, **kwargs): G.append( torch.cat( [ - p.grad.flatten() - if p.grad is not None - else torch.zeros(p.numel(), device=strategy.device) + ( + p.grad.flatten() + if p.grad is not None + else torch.zeros(p.numel(), device=strategy.device) + ) for p in strategy.model.parameters() ], dim=0, @@ -79,9 +81,11 @@ def after_backward(self, strategy, **kwargs): if strategy.clock.train_exp_counter > 0: g = torch.cat( [ - p.grad.flatten() - if p.grad is not None - else torch.zeros(p.numel(), device=strategy.device) + ( + p.grad.flatten() + if p.grad is not None + else torch.zeros(p.numel(), device=strategy.device) + ) for p in strategy.model.parameters() ], dim=0, diff --git a/avalanche/training/regularization.py b/avalanche/training/regularization.py index 2c15d3392..277dc1ee3 100644 --- a/avalanche/training/regularization.py +++ b/avalanche/training/regularization.py @@ -1,4 +1,5 @@ """Regularization methods.""" + import copy from collections import defaultdict from typing import List diff --git a/avalanche/training/supervised/__init__.py b/avalanche/training/supervised/__init__.py index a956677b3..b35bdbf78 100644 --- a/avalanche/training/supervised/__init__.py +++ b/avalanche/training/supervised/__init__.py @@ -2,6 +2,7 @@ This module contains """ + from .joint_training import * from .ar1 import AR1 from .cumulative import Cumulative diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index 67d84e3b5..8dad92f46 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -236,9 +236,9 @@ def _select_expert(self, strategy: "SupervisedTemplate", task_label): # Iterate through all autoencoders to get error values for autoencoder_id in strategy.model.autoencoder_dict: - error_dict[ - str(autoencoder_id) - ] = self._get_average_reconstruction_error(strategy, autoencoder_id) + error_dict[str(autoencoder_id)] = ( + self._get_average_reconstruction_error(strategy, autoencoder_id) + ) # Send error dictionary to get most relevant autoencoder relatedness_dict = self._task_relatedness(strategy, error_dict, task_label) diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 2bb151d80..3bb93df13 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -135,9 +135,9 @@ def train( ) # Normalize training and eval data. - experiences_list: Iterable[ - TDatasetExperience - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TDatasetExperience] = ( + _experiences_parameter_as_iterable(experiences) + ) if eval_streams is None: eval_streams = [experiences_list] diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index 81a9866cd..94f0b5381 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -183,9 +183,11 @@ def inner_update_step(self, fast_model, x, y, t): # Clip grad norms grads = [ - torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm) - if g is not None - else g + ( + torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm) + if g is not None + else g + ) for g in grads ] diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 9301a14c5..0838ecc61 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -185,9 +185,11 @@ def inner_update_step(self, fast_params, x, y, t): # Clip grad norms grads = [ - torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm) - if g is not None - else g + ( + torch.clamp(g, min=-self.grad_clip_norm, max=self.grad_clip_norm) + if g is not None + else g + ) for g in grads ] diff --git a/avalanche/training/templates/__init__.py b/avalanche/training/templates/__init__.py index 1d3308113..84a482291 100644 --- a/avalanche/training/templates/__init__.py +++ b/avalanche/training/templates/__init__.py @@ -9,6 +9,7 @@ Templates are the backbone that supports the plugin systems. """ + from .base import BaseTemplate from .base_sgd import BaseSGDTemplate from .common_templates import ( diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 4cd5abe22..26379272e 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -130,9 +130,9 @@ def train( self.model.to(self.device) # Normalize training and eval data. - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) if eval_streams is None: eval_streams = [experiences_list] @@ -184,9 +184,9 @@ def eval( self.is_training = False self.model.eval() - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) self.current_eval_stream = experiences_list self._before_eval(**kwargs) diff --git a/avalanche/training/templates/observation_type/__init__.py b/avalanche/training/templates/observation_type/__init__.py index f98135349..9f0cf3c00 100644 --- a/avalanche/training/templates/observation_type/__init__.py +++ b/avalanche/training/templates/observation_type/__init__.py @@ -2,4 +2,5 @@ batch(multiple epochs) vs. online(one epoch) """ + from .batch_observation import BatchObservation diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index 6bf2cb696..4a8dde613 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -1,9 +1,3 @@ -import warnings -from collections import defaultdict -from typing import List, TypeVar - -from torch import Tensor - from avalanche.benchmarks import OnlineCLExperience from avalanche.models.utils import avalanche_model_adaptation from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol diff --git a/avalanche/training/templates/problem_type/__init__.py b/avalanche/training/templates/problem_type/__init__.py index 0932beb4c..81cad2a6d 100644 --- a/avalanche/training/templates/problem_type/__init__.py +++ b/avalanche/training/templates/problem_type/__init__.py @@ -2,4 +2,5 @@ how inputs should be mapped to outputs. """ + from .supervised_problem import SupervisedProblem diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index ea3c54f8e..98f664f7e 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -54,53 +54,37 @@ class SGDStrategyProtocol( _criterion: Module - def forward(self) -> TMBoutput: - ... + def forward(self) -> TMBoutput: ... - def criterion(self) -> Tensor: - ... + def criterion(self) -> Tensor: ... - def backward(self) -> None: - ... + def backward(self) -> None: ... - def _make_empty_loss(self) -> Tensor: - ... + def _make_empty_loss(self) -> Tensor: ... - def make_optimizer(self, **kwargs): - ... + def make_optimizer(self, **kwargs): ... - def optimizer_step(self) -> None: - ... + def optimizer_step(self) -> None: ... - def model_adaptation(self, model: Optional[Module] = None) -> Module: - ... + def model_adaptation(self, model: Optional[Module] = None) -> Module: ... - def _unpack_minibatch(self): - ... + def _unpack_minibatch(self): ... - def _before_training_iteration(self, **kwargs): - ... + def _before_training_iteration(self, **kwargs): ... - def _before_forward(self, **kwargs): - ... + def _before_forward(self, **kwargs): ... - def _after_forward(self, **kwargs): - ... + def _after_forward(self, **kwargs): ... - def _before_backward(self, **kwargs): - ... + def _before_backward(self, **kwargs): ... - def _after_backward(self, **kwargs): - ... + def _after_backward(self, **kwargs): ... - def _before_update(self, **kwargs): - ... + def _before_update(self, **kwargs): ... - def _after_update(self, **kwargs): - ... + def _after_update(self, **kwargs): ... - def _after_training_iteration(self, **kwargs): - ... + def _after_training_iteration(self, **kwargs): ... class SupervisedStrategyProtocol( @@ -118,23 +102,17 @@ class MetaLearningStrategyProtocol( SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], Protocol[TSGDExperienceType, TMBinput, TMBoutput], ): - def _before_inner_updates(self, **kwargs): - ... + def _before_inner_updates(self, **kwargs): ... - def _inner_updates(self, **kwargs): - ... + def _inner_updates(self, **kwargs): ... - def _after_inner_updates(self, **kwargs): - ... + def _after_inner_updates(self, **kwargs): ... - def _before_outer_update(self, **kwargs): - ... + def _before_outer_update(self, **kwargs): ... - def _outer_update(self, **kwargs): - ... + def _outer_update(self, **kwargs): ... - def _after_outer_update(self, **kwargs): - ... + def _after_outer_update(self, **kwargs): ... __all__ = [ diff --git a/examples/detection.py b/examples/detection.py index 3e7f7d94e..ec2e92119 100644 --- a/examples/detection.py +++ b/examples/detection.py @@ -160,15 +160,11 @@ def obtain_base_model(segmentation: bool): pretrain_argument["pretrained"] = True else: if segmentation: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT ) else: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT ) diff --git a/examples/detection_lvis.py b/examples/detection_lvis.py index ed86e5b2c..330b7a9ec 100644 --- a/examples/detection_lvis.py +++ b/examples/detection_lvis.py @@ -142,15 +142,11 @@ def obtain_base_model(segmentation: bool): pretrain_argument["pretrained"] = True else: if segmentation: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT ) else: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT ) diff --git a/examples/hf_datawrapper.py b/examples/hf_datawrapper.py index 892b3af9c..9b38d6690 100644 --- a/examples/hf_datawrapper.py +++ b/examples/hf_datawrapper.py @@ -7,6 +7,7 @@ You can install them by running: pip install datasets transformers """ + import datasets as ds import numpy as np import torch diff --git a/examples/nlp_qa.py b/examples/nlp_qa.py index 7a6192d43..d6714de12 100644 --- a/examples/nlp_qa.py +++ b/examples/nlp_qa.py @@ -2,6 +2,7 @@ Simple example that show how to use Avalanche for Question Answering on Squad by using T5 """ + from avalanche.benchmarks.utils import DataAttribute, ConstantSequence from avalanche.training.plugins import ReplayPlugin from transformers import DataCollatorForSeq2Seq @@ -162,7 +163,7 @@ def preprocess_function( benchmark = CLScenario( [ CLStream("train", train_exps), - CLStream("valid", val_exp) + CLStream("valid", val_exp), # add more stream here (validation, test, out-of-domain, ...) ] ) diff --git a/examples/tvdetection/train.py b/examples/tvdetection/train.py index e73205295..0fa536674 100644 --- a/examples/tvdetection/train.py +++ b/examples/tvdetection/train.py @@ -17,6 +17,7 @@ Because the number of images is smaller in the person keypoint subset of COCO, the number of epochs should be adapted so that we have the same number of iterations. """ + import datetime import os import time diff --git a/profiling/data_merging.py b/profiling/data_merging.py index c0574f05f..ae230aa7c 100644 --- a/profiling/data_merging.py +++ b/profiling/data_merging.py @@ -2,6 +2,7 @@ Profiling script to measure the performance of the merging process of AvalancheDatasets """ + from unittest.mock import Mock import time diff --git a/profiling/serialization.py b/profiling/serialization.py index 92ea7d873..a404d6d47 100644 --- a/profiling/serialization.py +++ b/profiling/serialization.py @@ -4,6 +4,7 @@ SAVING TIME: 0.019003629684448242 LOADING TIME: 0.016000032424926758 """ + import os import time from torch.optim import SGD diff --git a/tests/test_loggers.py b/tests/test_loggers.py index 36f3403e3..4f50f67ef 100644 --- a/tests/test_loggers.py +++ b/tests/test_loggers.py @@ -3,6 +3,7 @@ Right now they don't do much but they at least check that the loggers run without errors. """ + import unittest from torch.optim import SGD From 8af1282298df7efd85663040af58beb9a6e2bb85 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 29 Jan 2024 18:24:30 +0100 Subject: [PATCH 04/21] Porting all strategies to keyword only parameters --- avalanche/training/supervised/ar1.py | 5 +- avalanche/training/supervised/cumulative.py | 9 +- avalanche/training/supervised/deep_slda.py | 25 +- avalanche/training/supervised/der.py | 9 +- avalanche/training/supervised/er_ace.py | 9 +- avalanche/training/supervised/er_aml.py | 15 +- avalanche/training/supervised/expert_gate.py | 20 +- .../training/supervised/feature_replay.py | 9 +- avalanche/training/supervised/icarl.py | 28 +- .../training/supervised/joint_training.py | 11 +- avalanche/training/supervised/l2p.py | 29 +- avalanche/training/supervised/lamaml.py | 7 +- avalanche/training/supervised/lamaml_v2.py | 7 +- avalanche/training/supervised/mer.py | 33 +- .../training/supervised/strategy_wrappers.py | 188 +++-- .../supervised/strategy_wrappers_online.py | 9 +- .../training/templates/common_templates.py | 118 ++- .../templates/update_type/meta_update.py | 3 + examples/naive.py | 8 +- pyproject.toml | 5 +- tests/benchmarks/utils/test_dataloaders.py | 6 +- tests/models/test_models.py | 48 +- tests/models/test_packnet.py | 24 +- tests/training/test_online_strategies.py | 28 +- tests/training/test_plugins.py | 6 +- tests/training/test_strategies.py | 779 +++++++++--------- tests/training/test_strategies_accuracy.py | 10 +- tests/training/test_supervised_regression.py | 13 +- 28 files changed, 847 insertions(+), 614 deletions(-) diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 8be1d75b9..999924e69 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -17,6 +17,7 @@ CWRStarPlugin, ) from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import ( replace_bn_with_brn, get_last_fc_layer, @@ -42,7 +43,8 @@ class AR1(SupervisedTemplate): def __init__( self, - criterion=None, + *args, + criterion: CriterionType = None, lr: float = 0.001, inc_lr: float = 5e-5, momentum=0.9, @@ -167,6 +169,7 @@ def __init__( self.replay_mb_size = 0 super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 5596ac592..c324bbef3 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -8,6 +8,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates.base_sgd import CriterionType class Cumulative(SupervisedTemplate): @@ -19,9 +20,10 @@ class Cumulative(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -53,6 +55,7 @@ def __init__( """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 00468612f..2ff1ae4cd 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -12,6 +12,7 @@ ) from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models import FeatureExtractorBackbone +from avalanche.training.templates.base_sgd import CriterionType class StreamingLDA(SupervisedTemplate): @@ -28,11 +29,12 @@ class StreamingLDA(SupervisedTemplate): def __init__( self, - slda_model, - criterion, - input_size, - num_classes, - output_layer_name=None, + *args, + slda_model="not_set", + criterion: CriterionType = "not_set", + input_size="not_set", + num_classes: int = "not_set", + output_layer_name: Optional[str] = None, shrinkage_param=1e-4, streaming_update_sigma=True, train_epochs: int = 1, @@ -48,7 +50,7 @@ def __init__( ): """Init function for the SLDA model. - :param slda_model: a PyTorch model + :param model: a PyTorch model :param criterion: loss function :param output_layer_name: if not None, wrap model to retrieve only the `output_layer_name` output. If None, the strategy @@ -72,13 +74,24 @@ def __init__( if plugins is None: plugins = [] + adapt_legacy_args = slda_model == "not_set" + slda_model = slda_model if slda_model != "not_set" else args[0] slda_model = slda_model.eval() if output_layer_name is not None: slda_model = FeatureExtractorBackbone( slda_model.to(device), output_layer_name ).eval() + # Legacy positional arguments support (deprecation cycle) + if adapt_legacy_args: + args = list(args) + args[0] = slda_model + + input_size = input_size if input_size != "not_set" else args[2] + num_classes = num_classes if num_classes != "not_set" else args[3] + super(StreamingLDA, self).__init__( + legacy_positional_args=args, model=slda_model, optimizer=None, # type: ignore criterion=criterion, diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index 8b6e87b86..bbc4b2186 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -19,6 +19,7 @@ from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute from avalanche.benchmarks.utils.flat_data import FlatData +from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import cycle from avalanche.core import SupervisedPlugin from avalanche.training.plugins.evaluation import ( @@ -156,9 +157,10 @@ class DER(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), mem_size: int = 200, batch_size_mem: Optional[int] = None, alpha: float = 0.1, @@ -209,6 +211,7 @@ def __init__( `eval_every` experience or iterations (Default='experience'). """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index daf7ee480..d68044330 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -13,6 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import cycle @@ -32,9 +33,10 @@ class ER_ACE(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), mem_size: int = 200, batch_size_mem: int = 10, train_mb_size: int = 1, @@ -72,6 +74,7 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/er_aml.py b/avalanche/training/supervised/er_aml.py index d247f8b76..efff025c0 100644 --- a/avalanche/training/supervised/er_aml.py +++ b/avalanche/training/supervised/er_aml.py @@ -13,6 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import cycle @@ -27,10 +28,11 @@ class ER_AML(SupervisedTemplate): def __init__( self, - model: Module, - feature_extractor: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + feature_extractor: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), temp: float = 0.1, base_temp: float = 0.07, same_task_neg: bool = True, @@ -75,6 +77,7 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -95,7 +98,9 @@ def __init__( ) self.replay_loader = None self.aml_criterion = AMLCriterion( - feature_extractor=feature_extractor, + feature_extractor=( + feature_extractor if feature_extractor != "not_set" else args[1] + ), temp=temp, base_temp=base_temp, same_task_neg=same_task_neg, diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index 8dad92f46..21c8b0fca 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -22,6 +22,7 @@ from avalanche.training.templates import SupervisedTemplate from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.templates.base_sgd import CriterionType class ExpertGateStrategy(SupervisedTemplate): @@ -41,9 +42,10 @@ class ExpertGateStrategy(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -85,10 +87,6 @@ def __init__( expert during the forward method :class:`~avalanche.training.BaseTemplate` constructor arguments. """ - # Check that the model has the correct architecture. - assert isinstance( - model, ExpertGate - ), "ExpertGateStrategy requires an ExpertGate model." expertgate = _ExpertGatePlugin() @@ -102,7 +100,6 @@ def __init__( self.ae_lr = ae_lr self.ae_latent_dim = ae_latent_dim self.rel_thresh = rel_thresh - model.temp = temp warnings.warn( "This strategy is currently in the alpha stage and we are still " @@ -112,6 +109,7 @@ def __init__( ) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -125,6 +123,12 @@ def __init__( **base_kwargs ) + # Check that the model has the correct architecture. + assert isinstance( + self.model, ExpertGate + ), "ExpertGateStrategy requires an ExpertGate model." + self.model.temp = temp + class _ExpertGatePlugin(SupervisedPlugin): """The ExpertGate algorithm is a dynamic architecture algorithm. diff --git a/avalanche/training/supervised/feature_replay.py b/avalanche/training/supervised/feature_replay.py index 40e29f2b5..be7293daf 100644 --- a/avalanche/training/supervised/feature_replay.py +++ b/avalanche/training/supervised/feature_replay.py @@ -11,6 +11,7 @@ from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate +from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import cycle from avalanche.training.losses import MaskedCrossEntropy @@ -46,9 +47,10 @@ class FeatureReplay(SupervisedTemplate): def __init__( self, - model: nn.Module, - optimizer: Optimizer, - criterion=MaskedCrossEntropy(), + *args, + model: nn.Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = MaskedCrossEntropy(), last_layer_name: str = "classifier", mem_size: int = 200, batch_size_mem: int = 10, @@ -65,6 +67,7 @@ def __init__( **kwargs ): super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/icarl.py b/avalanche/training/supervised/icarl.py index 9e2ca25d1..fb81d8c17 100644 --- a/avalanche/training/supervised/icarl.py +++ b/avalanche/training/supervised/icarl.py @@ -28,12 +28,13 @@ class ICaRL(SupervisedTemplate): def __init__( self, - feature_extractor: Module, - classifier: Module, - optimizer: Optimizer, - memory_size, - buffer_transform, - fixed_memory, + *args, + feature_extractor: Module = "not_set", + classifier: Module = "not_set", + optimizer: Optimizer = "not_set", + memory_size="not_set", + buffer_transform="not_set", + fixed_memory="not_set", train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -72,13 +73,17 @@ def __init__( learning experience. """ model = TrainEvalModel( - feature_extractor, - train_classifier=classifier, + feature_extractor if feature_extractor != "not_set" else args[0], + train_classifier=classifier if classifier != "not_set" else args[1], eval_classifier=NCMClassifier(normalize=True), ) criterion = ICaRLLossPlugin() # iCaRL requires this specific loss (#966) - icarl = _ICaRLPlugin(memory_size, buffer_transform, fixed_memory) + icarl = _ICaRLPlugin( + memory_size if memory_size != "not_set" else args[3], + buffer_transform if buffer_transform != "not_set" else args[4], + fixed_memory if fixed_memory != "not_set" else args[5], + ) if plugins is None: plugins = [icarl] @@ -89,8 +94,9 @@ def __init__( plugins += [criterion] super().__init__( - model, - optimizer, + legacy_positional_args=args, + model=model, + optimizer=optimizer, criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 3bb93df13..a52374a30 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -28,6 +28,7 @@ _experiences_parameter_as_iterable, _group_experiences_by_stream, ) +from avalanche.training.templates.base_sgd import CriterionType class AlreadyTrainedError(Exception): @@ -57,9 +58,10 @@ class JointTraining(SupervisedTemplate[TDatasetExperience, TMBInput, TMBOutput]) def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, @@ -69,6 +71,7 @@ def __init__( EvaluationPlugin, Callable[[], EvaluationPlugin] ] = default_evaluator, eval_every=-1, + **kwargs ): """Init. @@ -88,6 +91,7 @@ def __init__( `eval` is called every `eval_every` epochs and at the end of the learning experience.""" super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -98,6 +102,7 @@ def __init__( plugins=plugins, evaluator=evaluator, eval_every=eval_every, + **kwargs ) # JointTraining can be trained only once. self._is_fitted = False diff --git a/avalanche/training/supervised/l2p.py b/avalanche/training/supervised/l2p.py index 715498c20..410df3b57 100644 --- a/avalanche/training/supervised/l2p.py +++ b/avalanche/training/supervised/l2p.py @@ -31,7 +31,8 @@ class LearningToPrompt(SupervisedTemplate): def __init__( self, - model_name: str, + *args, + model_name: str = "not_set", criterion: nn.Module = nn.CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -98,7 +99,7 @@ def __init__( self.lr = lr self.sim_coefficient = sim_coefficient model = create_model( - model_name=model_name, + model_name=model_name if model_name != "not_set" else args[0], prompt_pool=prompt_pool, pool_size=pool_size, prompt_length=prompt_length, @@ -130,17 +131,19 @@ def __init__( ) super().__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + legacy_positional_args=args, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs, ) self._criterion = criterion diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index 94f0b5381..35eb0cc46 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -27,9 +27,10 @@ class LaMAML(SupervisedMetaLearningTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), n_inner_updates: int = 5, second_order: bool = True, grad_clip_norm: float = 1.0, diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 0838ecc61..fc0492aee 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -24,9 +24,10 @@ class LaMAML(SupervisedMetaLearningTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), n_inner_updates: int = 5, second_order: bool = True, grad_clip_norm: float = 1.0, diff --git a/avalanche/training/supervised/mer.py b/avalanche/training/supervised/mer.py index b1a100cdf..b546eb5d2 100644 --- a/avalanche/training/supervised/mer.py +++ b/avalanche/training/supervised/mer.py @@ -11,6 +11,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.storage_policy import ReservoirSamplingBuffer from avalanche.training.templates import SupervisedMetaLearningTemplate +from avalanche.training.templates.base_sgd import CriterionType class MERBuffer: @@ -51,9 +52,10 @@ def get_batch(self, x, y, t): class MER(SupervisedMetaLearningTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), mem_size=200, batch_size_mem=10, n_inner_steps=5, @@ -69,6 +71,7 @@ def __init__( ] = default_evaluator, eval_every=-1, peval_mode="epoch", + **kwargs ): """Implementation of Look-ahead MAML (LaMAML) algorithm in Avalanche using Higher library for applying fast updates. @@ -85,17 +88,19 @@ def __init__( """ super().__init__( - model, - optimizer, - criterion, - train_mb_size, - train_epochs, - eval_mb_size, - device, - plugins, - evaluator, - eval_every, - peval_mode, + legacy_positional_args=args, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.buffer = MERBuffer( diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index e9c5b945b..f8d6ffce7 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -49,6 +49,7 @@ from avalanche.models.generator import MlpVAE, VAE_loss from avalanche.models.expert_gate import AE_loss from avalanche.logging import InteractiveLogger +from avalanche.training.templates.base_sgd import CriterionType class Naive(SupervisedTemplate): @@ -66,9 +67,9 @@ class Naive(SupervisedTemplate): def __init__( self, *args, - model: Module = None, - optimizer: Optimizer = None, - criterion=CrossEntropyLoss(), + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -103,7 +104,7 @@ def __init__( """ super().__init__( - *args, + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -127,8 +128,8 @@ class PNNStrategy(SupervisedTemplate): def __init__( self, *args, - model: Module = None, - optimizer: Optimizer = None, + model: Module = "not_set", + optimizer: Optimizer = "not_set", criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -161,12 +162,9 @@ def __init__( :param base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ - # Check that the model has the correct architecture. - assert isinstance(model, PNN) or ( - len(args) > 0 and isinstance(args[0], PNN) - ), "PNNStrategy requires a PNN model." + super().__init__( - *args, + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -180,6 +178,9 @@ def __init__( **base_kwargs ) + # Check that the model has the correct architecture. + assert isinstance(self.model, PNN), "PNNStrategy requires a PNN model." + class PackNet(SupervisedTemplate): """Task-incremental fixed-network parameter isolation with PackNet. @@ -202,10 +203,11 @@ class PackNet(SupervisedTemplate): def __init__( self, - model: Union[PackNetModule, PackNetModel], - optimizer: Optimizer, - post_prune_epochs: int, - prune_proportion: float, + *args, + model: Union[PackNetModule, PackNetModel] = "not_set", + optimizer: Optimizer = "not_set", + post_prune_epochs: int = "not_set", + prune_proportion: float = "not_set", criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -247,10 +249,8 @@ def __init__( :class:`~avalanche.training.BaseTemplate` constructor arguments. """ - if not isinstance(model, PackNetModule): - raise ValueError("PackNet requires a model that implements PackNetModule.") - super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -264,15 +264,19 @@ def __init__( **base_kwargs ) + if not isinstance(self.model, PackNetModule): + raise ValueError("PackNet requires a model that implements PackNetModule.") + class CWRStar(SupervisedTemplate): """CWR* Strategy.""" def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", cwr_layer_name: str, train_mb_size: int = 1, train_epochs: int = 1, @@ -307,12 +311,18 @@ def __init__( :param \*\*base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ - cwsp = CWRStarPlugin(model, cwr_layer_name, freeze_remaining_model=True) + + cwsp = CWRStarPlugin( + model if model != "not_set" else args[0], + cwr_layer_name, + freeze_remaining_model=True, + ) if plugins is None: plugins = [cwsp] else: plugins.append(cwsp) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -336,9 +346,10 @@ class Replay(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", mem_size: int = 200, train_mb_size: int = 1, train_epochs: int = 1, @@ -379,6 +390,7 @@ def __init__( else: plugins.append(rp) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -414,9 +426,10 @@ class GenerativeReplay(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -510,6 +523,7 @@ def __init__( plugins.append(rp) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -543,6 +557,7 @@ class AETraining(SupervisedTemplate): def __init__( self, + *args, model: Module, optimizer: Optimizer, device, @@ -578,6 +593,7 @@ def __init__( """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -615,8 +631,9 @@ class VAETraining(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", criterion=VAE_loss, train_mb_size: int = 1, train_epochs: int = 1, @@ -652,6 +669,7 @@ def __init__( """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -680,9 +698,10 @@ class GSS_greedy(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", mem_size: int = 200, mem_strength=1, input_size=[], @@ -750,9 +769,10 @@ class GDumb(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", mem_size: int = 200, train_mb_size: int = 1, train_epochs: int = 1, @@ -794,6 +814,7 @@ def __init__( plugins.append(gdumb) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -816,9 +837,10 @@ class LwF(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", alpha: Union[float, Sequence[float]], temperature: float, train_mb_size: int = 1, @@ -863,6 +885,7 @@ def __init__( plugins.append(lwf) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -886,9 +909,10 @@ class AGEM(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", patterns_per_exp: int, sample_size: int = 64, train_mb_size: int = 1, @@ -933,6 +957,7 @@ def __init__( plugins.append(agem) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -956,9 +981,10 @@ class GEM(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", patterns_per_exp: int, memory_strength: float = 0.5, train_mb_size: int = 1, @@ -1003,6 +1029,7 @@ def __init__( plugins.append(gem) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1026,9 +1053,10 @@ class EWC(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", ewc_lambda: float, mode: str = "separate", decay_factor: Optional[float] = None, @@ -1084,6 +1112,7 @@ def __init__( plugins.append(ewc) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1116,9 +1145,10 @@ class SynapticIntelligence(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", si_lambda: Union[float, Sequence[float]], eps: float = 0.0000001, train_mb_size: int = 1, @@ -1169,6 +1199,7 @@ def __init__( plugins.append(SynapticIntelligencePlugin(si_lambda=si_lambda, eps=eps)) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1192,9 +1223,10 @@ class CoPE(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", mem_size: int = 200, n_classes: int = 10, p_size: int = 100, @@ -1247,6 +1279,7 @@ def __init__( else: plugins.append(copep) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1271,9 +1304,10 @@ class LFL(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", lambda_e: Union[float, Sequence[float]], train_mb_size: int = 1, train_epochs: int = 1, @@ -1315,6 +1349,7 @@ def __init__( plugins.append(lfl) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1338,9 +1373,10 @@ class MAS(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", lambda_reg: float = 1.0, alpha: float = 0.5, verbose: bool = False, @@ -1392,6 +1428,7 @@ def __init__( plugins.append(mas) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1415,9 +1452,10 @@ class BiC(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", mem_size: int = 200, val_percentage: float = 0.1, T: int = 2, @@ -1484,6 +1522,7 @@ def __init__( plugins.append(bic) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1505,11 +1544,12 @@ class MIR(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, - mem_size: int, - subsample: int, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", + mem_size: int = "not_set", + subsample: int = "not_set", batch_size_mem: int = 1, train_mb_size: int = 1, train_epochs: int = 1, @@ -1549,7 +1589,9 @@ def __init__( # Instantiate plugin mir = MIRPlugin( - mem_size=mem_size, subsample=subsample, batch_size_mem=batch_size_mem + mem_size=mem_size if mem_size != "not_set" else args[3], + subsample=subsample if subsample != "not_set" else args[5], + batch_size_mem=batch_size_mem, ) # Add plugin to the strategy @@ -1559,6 +1601,7 @@ def __init__( plugins.append(mir) super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1584,9 +1627,10 @@ class FromScratchTraining(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion, + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = "not_set", reset_optimizer: bool = True, train_mb_size: int = 1, train_epochs: int = 1, diff --git a/avalanche/training/supervised/strategy_wrappers_online.py b/avalanche/training/supervised/strategy_wrappers_online.py index 640f5abb4..b4fb71d39 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -21,6 +21,7 @@ SupervisedTemplate, ) from avalanche._annotations import deprecated +from avalanche.training.templates.base_sgd import CriterionType @deprecated( @@ -43,9 +44,10 @@ class OnlineNaive(SupervisedTemplate): def __init__( self, - model: Module, - optimizer: Optimizer, - criterion=CrossEntropyLoss(), + *args, + model: Module = "not_set", + optimizer: Optimizer = "not_set", + criterion: CriterionType = CrossEntropyLoss(), train_passes: int = 1, train_mb_size: int = 1, eval_mb_size: Optional[int] = None, @@ -78,6 +80,7 @@ def __init__( learning experience. """ super().__init__( + legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 9f48e1d7b..d15f00bcc 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,7 +1,8 @@ import sys -from typing import Callable, Sequence, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Sequence, Optional, TypeVar, Union import warnings import torch +import inspect from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer @@ -28,31 +29,96 @@ TMBOutput = TypeVar("TMBOutput") -def _merge_legacy_positional_arguments(args, kwargs): - # Manage the legacy positional constructor parameters +class PositionalArgumentDeprecatedWarning(UserWarning): + pass + + +def _merge_legacy_positional_arguments( + self, + args: Optional[Sequence[Any]], + kwargs: Dict[str, Any], + strict_init_check=True, + allow_pos_args=True, +): + """ + Manage the legacy positional constructor parameters. + + Used to warn the user about the deprecation of positional parameters + in strategy constructors. + + To allow for a smooth transition, we allow the user to pass positional + arguments to the constructor. However, we want to warn the user that + this is deprecated and will be removed in the future. + """ + + init_method = getattr(self, "__init__") + init_kwargs = inspect.signature(init_method).parameters + + if self.__class__ in [SupervisedTemplate, SupervisedMetaLearningTemplate]: + return + + has_transition_varargs = any( + x.kind == inspect.Parameter.VAR_POSITIONAL for x in init_kwargs.values() + ) + if (not has_transition_varargs and args is not None) or ( + has_transition_varargs and args is None + ): + error_str = ( + "While trainsitioning to the new keyword-only strategy constructors, " + "the ability to pass positional arguments (as in older versions of Avalanche) should be granted. " + "You should catch all positional arguments (excluding self) using *args and do not declare other positional arguments! " + "Then, pass them to SupervisedTemplate/SupervisedMetaLearningTemplate super constructor as the legacy_args argument." + "Those legacy positional arguments will then be converted to keyword arguments " + "according to the declaration order of those keyword arguments." + ) + + if strict_init_check: + raise PositionalArgumentDeprecatedWarning(error_str) + else: + warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + + positional_args = { + (k, x) + for k, x in init_kwargs.items() + if x.kind + in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] + } + + if len(positional_args) > 0: + all_positional_args_names_as_str = ", ".join([x[0] for x in positional_args]) + error_str = ( + "Avalanche is transitioning to strategy constructors that allow for keyword arguments only. " + "Your strategy __init__ method still has some positional-only or positional-or-keyword arguments. " + "Consider removing them. Offending arguments: " + + all_positional_args_names_as_str + ) + + if strict_init_check: + raise PositionalArgumentDeprecatedWarning(error_str) + else: + warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + + # Discard all non-keyword params (also exclude VAR_KEYWORD) + kwargs_order = [ + k for k, v in init_kwargs.items() if v.kind == inspect.Parameter.KEYWORD_ONLY + ] + + print(init_kwargs) + if len(args) > 0: - warnings.warn( + + error_str = ( "Passing positional arguments to Strategy constructors is " "deprecated. Please use keyword arguments instead." ) - # unroll args and apply it to kwargs - legacy_kwargs_order = [ - "model", - "optimizer", - "criterion", - "train_mb_size", - "train_epochs", - "eval_mb_size", - "device", - "plugins", - "evaluator", - "eval_every", - "peval_mode", - ] + if allow_pos_args: + warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + else: + raise PositionalArgumentDeprecatedWarning(error_str) for i, arg in enumerate(args): - kwargs[legacy_kwargs_order[i]] = arg + kwargs[kwargs_order[i]] = arg for key, value in kwargs.items(): if value == "not_set": @@ -117,7 +183,8 @@ class SupervisedTemplate( # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, - *args, + *, + legacy_positional_args: Optional[Sequence[Any]], model: Module = "not_set", optimizer: Optimizer = "not_set", criterion: CriterionType = CrossEntropyLoss(), @@ -169,7 +236,9 @@ def __init__( kwargs["eval_every"] = eval_every kwargs["peval_mode"] = peval_mode - kwargs = _merge_legacy_positional_arguments(args, kwargs) + kwargs = _merge_legacy_positional_arguments( + self, legacy_positional_args, kwargs + ) if sys.version_info >= (3, 11): super().__init__(**kwargs) @@ -233,7 +302,8 @@ class SupervisedMetaLearningTemplate( # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, - *args, + *, + legacy_positional_args: Optional[Sequence[Any]], model: Module = "not_set", optimizer: Optimizer = "not_set", criterion=CrossEntropyLoss(), @@ -285,7 +355,9 @@ def __init__( kwargs["eval_every"] = eval_every kwargs["peval_mode"] = peval_mode - kwargs = _merge_legacy_positional_arguments(args, kwargs) + kwargs = _merge_legacy_positional_arguments( + self, legacy_positional_args, kwargs + ) if sys.version_info >= (3, 11): super().__init__(**kwargs) diff --git a/avalanche/training/templates/update_type/meta_update.py b/avalanche/training/templates/update_type/meta_update.py index 79e8a5016..2217d2d07 100644 --- a/avalanche/training/templates/update_type/meta_update.py +++ b/avalanche/training/templates/update_type/meta_update.py @@ -5,6 +5,9 @@ class MetaUpdate(MetaLearningStrategyProtocol): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def training_epoch(self, **kwargs): """Training epoch. diff --git a/examples/naive.py b/examples/naive.py index e63241c90..8df7a9a2e 100644 --- a/examples/naive.py +++ b/examples/naive.py @@ -39,10 +39,10 @@ def main(): # create strategy strategy = Naive( - model=model, - optimizer=optimizer, - criterion=criterion, - train_epochs=1, + model, + optimizer, + criterion, + 1, device=device, train_mb_size=32, evaluator=eval_plugin, diff --git a/pyproject.toml b/pyproject.toml index 192d35128..1f0d593dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,4 +59,7 @@ ignore_errors = true [[tool.mypy.overrides]] module = "profiling.*" -ignore_errors = true \ No newline at end of file +ignore_errors = true + +[tool.pytest.ini_options] +addopts="-n 4" \ No newline at end of file diff --git a/tests/benchmarks/utils/test_dataloaders.py b/tests/benchmarks/utils/test_dataloaders.py index 31a1c47cd..8a6892828 100644 --- a/tests/benchmarks/utils/test_dataloaders.py +++ b/tests/benchmarks/utils/test_dataloaders.py @@ -157,9 +157,9 @@ def test_dataloader_with_multiple_workers(self): # Continual learning strategy cl_strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=32, train_epochs=1, eval_mb_size=32, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 7fdf67996..1dd49cc02 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -157,9 +157,9 @@ def test_optimizers(self): model, "SGDmom", "Adam", "SGD", "AdamW" ): strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -172,9 +172,9 @@ def test_checkpointing(self): model, criterion, benchmark = self.init_scenario(multi_task=True) optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -190,9 +190,9 @@ def test_checkpointing(self): model, criterion, benchmark = self.init_scenario(multi_task=True) optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -226,9 +226,9 @@ def test_mh_classifier(self): model, criterion, benchmark = self.init_scenario(multi_task=True) optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -272,9 +272,9 @@ def test_incremental_classifier(self): benchmark = self.benchmark strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=100, train_epochs=1, eval_mb_size=100, @@ -334,9 +334,9 @@ def test_incremental_classifier_weight_update(self): benchmark = self.benchmark strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=100, train_epochs=1, eval_mb_size=100, @@ -378,9 +378,9 @@ def test_multihead_head_creation(self): benchmark = get_fast_benchmark(use_task_labels=True, shuffle=False) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=100, train_epochs=1, eval_mb_size=100, @@ -432,9 +432,9 @@ def test_multihead_head_selection(self): benchmark = get_fast_benchmark(use_task_labels=True, shuffle=False) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=100, train_epochs=1, eval_mb_size=100, diff --git a/tests/models/test_packnet.py b/tests/models/test_packnet.py index 1b0caff40..f87992cbd 100644 --- a/tests/models/test_packnet.py +++ b/tests/models/test_packnet.py @@ -1,6 +1,9 @@ import unittest from avalanche.models.packnet import PackNetModel, packnet_simple_mlp from avalanche.training.supervised.strategy_wrappers import PackNet +from avalanche.training.templates.common_templates import ( + PositionalArgumentDeprecatedWarning, +) from torch.optim import SGD, Adam import torch import os @@ -26,16 +29,17 @@ def _test_PackNetPlugin(self, expectations, optimizer_constructor, lr): ) model = construct_model() optimizer = optimizer_constructor(model.parameters(), lr=lr) - strategy = PackNet( - model, - prune_proportion=0.5, - post_prune_epochs=1, - optimizer=optimizer, - train_epochs=2, - train_mb_size=10, - eval_mb_size=10, - device="cuda" if use_gpu else "cpu", - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = PackNet( + model, + prune_proportion=0.5, + post_prune_epochs=1, + optimizer=optimizer, + train_epochs=2, + train_mb_size=10, + eval_mb_size=10, + device="cuda" if use_gpu else "cpu", + ) x_test = torch.rand(10, 6) t_test = torch.ones(10, dtype=torch.long) task_ouputs = [] diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py index 520a837d2..aec164887 100644 --- a/tests/training/test_online_strategies.py +++ b/tests/training/test_online_strategies.py @@ -9,6 +9,9 @@ from avalanche.models import SimpleMLP from avalanche.benchmarks.scenarios.online import OnlineCLScenario from avalanche.training import OnlineNaive +from avalanche.training.templates.common_templates import ( + PositionalArgumentDeprecatedWarning, +) from tests.unit_tests_utils import get_fast_benchmark from avalanche.training.plugins.evaluation import default_evaluator @@ -46,24 +49,25 @@ def test_naive(self): # With task boundaries model, optimizer, criterion, _ = self.init_sit() - strategy = OnlineNaive( - model, - optimizer, - criterion, - train_mb_size=1, - device=self.device, - eval_mb_size=50, - evaluator=default_evaluator, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = OnlineNaive( + model, + optimizer, + criterion=criterion, + train_mb_size=1, + device=self.device, + eval_mb_size=50, + evaluator=default_evaluator, + ) ocl_benchmark = OnlineCLScenario(benchmark_streams, access_task_boundaries=True) self.run_strategy_boundaries(ocl_benchmark, strategy) # Without task boundaries model, optimizer, criterion, my_nc_benchmark = self.init_sit() strategy = OnlineNaive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=1, device=self.device, eval_mb_size=50, diff --git a/tests/training/test_plugins.py b/tests/training/test_plugins.py index 04e591f99..6c223c73e 100644 --- a/tests/training/test_plugins.py +++ b/tests/training/test_plugins.py @@ -118,9 +118,9 @@ def test_callback_reachability(self): plug = MockPlugin() strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=100, train_epochs=1, eval_mb_size=100, diff --git a/tests/training/test_strategies.py b/tests/training/test_strategies.py index 61c3592ff..095c7224a 100644 --- a/tests/training/test_strategies.py +++ b/tests/training/test_strategies.py @@ -60,6 +60,9 @@ from avalanche.training.supervised.strategy_wrappers import PNNStrategy from avalanche.training.templates import SupervisedTemplate from avalanche.training.templates.base import _group_experiences_by_stream +from avalanche.training.templates.common_templates import ( + PositionalArgumentDeprecatedWarning, +) from avalanche.training.utils import get_last_fc_layer from tests.training.test_strategy_utils import run_strategy from tests.unit_tests_utils import get_fast_benchmark, get_device @@ -110,9 +113,9 @@ def test_periodic_eval(self): # for each eval loop. acc = StreamAccuracy() strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=2, eval_every=-1, evaluator=EvaluationPlugin(acc), @@ -127,9 +130,9 @@ def test_periodic_eval(self): acc = StreamAccuracy() evalp = EvaluationPlugin(acc) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=2, eval_every=0, evaluator=evalp, @@ -144,9 +147,9 @@ def test_periodic_eval(self): ################### acc = StreamAccuracy() strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=2, eval_every=1, evaluator=EvaluationPlugin(acc), @@ -160,9 +163,9 @@ def test_periodic_eval(self): ################### acc = StreamAccuracy() strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=2, eval_every=100, evaluator=EvaluationPlugin(acc), @@ -185,9 +188,9 @@ def test_plugins_compatibility_checks(self): ) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_epochs=2, eval_every=-1, evaluator=evalp, @@ -226,9 +229,9 @@ def after_training_iteration( optimizer = SGD(model.parameters(), lr=1) strategy = Cumulative( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=1, device=get_device(), eval_mb_size=512, @@ -271,23 +274,26 @@ def init_scenario(self, multi_task=False): def test_naive(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = Naive( - model, - optimizer, - criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = Naive( + model, + optimizer, + criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -313,25 +319,28 @@ def after_train_dataset_adaptation( # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = JointTraining( - model, - optimizer, - criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - plugins=[JointSTestPlugin(benchmark)], - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = JointTraining( + model, + optimizer, + criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + plugins=[JointSTestPlugin(benchmark)], + ) + strategy.evaluator.loggers = [TextLogger(sys.stdout)] strategy.train(benchmark.train_stream) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) + strategy = JointTraining( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -351,14 +360,17 @@ def test_cwrstar(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) last_fc_name, _ = get_last_fc_layer(model) - strategy = CWRStar( - model, - optimizer, - criterion, - last_fc_name, - train_mb_size=64, - device=self.device, - ) + + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = CWRStar( + model, + optimizer, + criterion=criterion, + cwr_layer_name=last_fc_name, + train_mb_size=64, + device=self.device, + ) + run_strategy(benchmark, strategy) dict_past_j = {} @@ -378,10 +390,10 @@ def test_cwrstar(self): # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = CWRStar( - model, - optimizer, - criterion, - last_fc_name, + model=model, + optimizer=optimizer, + criterion=criterion, + cwr_layer_name=last_fc_name, train_mb_size=64, device=self.device, ) @@ -401,25 +413,27 @@ def test_cwrstar(self): def test_replay(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = Replay( - model, - optimizer, - criterion, - mem_size=10, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = Replay( + model, + optimizer, + criterion=criterion, + mem_size=10, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = Replay( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, mem_size=10, train_mb_size=64, device=self.device, @@ -432,25 +446,27 @@ def test_replay(self): def test_mer(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = MER( - model, - optimizer, - criterion, - mem_size=10, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = MER( + model, + optimizer, + criterion=criterion, + mem_size=10, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = MER( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, mem_size=10, train_mb_size=64, device=self.device, @@ -463,24 +479,26 @@ def test_mer(self): def test_gdumb(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = GDumb( - model, - optimizer, - criterion, - mem_size=200, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = GDumb( + model, + optimizer, + criterion, + mem_size=200, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) + run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = GDumb( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, mem_size=200, train_mb_size=64, device=self.device, @@ -492,23 +510,24 @@ def test_gdumb(self): def test_cumulative(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = Cumulative( - model, - optimizer, - criterion, - train_mb_size=64, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = Cumulative( + model, + optimizer, + criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = Cumulative( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=self.device, eval_mb_size=50, @@ -518,17 +537,19 @@ def test_cumulative(self): def test_slda(self): model, _, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = StreamingLDA( - model, - criterion, - input_size=10, - output_layer_name="features", - num_classes=10, - eval_mb_size=7, - train_epochs=1, - device=self.device, - train_mb_size=7, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = StreamingLDA( + model, + criterion, + 10, + 10, + output_layer_name="features", + eval_mb_size=7, + train_epochs=1, + device=self.device, + train_mb_size=7, + ) + run_strategy(benchmark, strategy) def test_warning_slda_lwf(self): @@ -546,25 +567,26 @@ def test_warning_slda_lwf(self): def test_lwf(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = LwF( - model, - optimizer, - criterion, - alpha=[0, 1 / 2, 2 * (2 / 3), 3 * (3 / 4), 4 * (4 / 5)], - temperature=2, - device=self.device, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = LwF( + model, + optimizer, + criterion, + alpha=[0, 1 / 2, 2 * (2 / 3), 3 * (3 / 4), 4 * (4 / 5)], + temperature=2, + device=self.device, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = LwF( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, alpha=[0, 1 / 2, 2 * (2 / 3), 3 * (3 / 4), 4 * (4 / 5)], temperature=2, device=self.device, @@ -577,24 +599,25 @@ def test_lwf(self): def test_agem(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = AGEM( - model, - optimizer, - criterion, - patterns_per_exp=25, - sample_size=25, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = AGEM( + model, + optimizer, + criterion, + patterns_per_exp=25, + sample_size=25, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = AGEM( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, patterns_per_exp=25, sample_size=25, train_mb_size=10, @@ -606,24 +629,25 @@ def test_agem(self): def test_gem(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = GEM( - model, - optimizer, - criterion, - patterns_per_exp=256, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = GEM( + model, + optimizer, + criterion, + patterns_per_exp=256, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = GEM( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, patterns_per_exp=256, train_mb_size=10, eval_mb_size=50, @@ -635,25 +659,26 @@ def test_gem(self): def test_ewc(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = EWC( - model, - optimizer, - criterion, - ewc_lambda=0.4, - mode="separate", - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = EWC( + model, + optimizer, + criterion, + ewc_lambda=0.4, + mode="separate", + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = EWC( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, ewc_lambda=0.4, mode="separate", train_mb_size=10, @@ -665,25 +690,26 @@ def test_ewc(self): def test_ewc_online(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = EWC( - model, - optimizer, - criterion, - ewc_lambda=0.4, - mode="online", - decay_factor=0.1, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = EWC( + model, + optimizer, + criterion, + ewc_lambda=0.4, + mode="online", + decay_factor=0.1, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = EWC( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, ewc_lambda=0.4, mode="online", decay_factor=0.1, @@ -696,29 +722,30 @@ def test_ewc_online(self): def test_rwalk(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = Naive( - model, - optimizer, - criterion, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - plugins=[ - RWalkPlugin( - ewc_lambda=0.1, - ewc_alpha=0.9, - delta_t=10, - ), - ], - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = Naive( + model, + optimizer, + criterion, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + plugins=[ + RWalkPlugin( + ewc_lambda=0.1, + ewc_alpha=0.9, + delta_t=10, + ), + ], + ) run_strategy(benchmark, strategy) # # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=10, eval_mb_size=50, train_epochs=2, @@ -735,23 +762,24 @@ def test_rwalk(self): def test_synaptic_intelligence(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = SynapticIntelligence( - model, - optimizer, - criterion, - si_lambda=0.0001, - train_epochs=1, - train_mb_size=10, - eval_mb_size=10, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = SynapticIntelligence( + model, + optimizer, + criterion, + si_lambda=0.0001, + train_epochs=1, + train_mb_size=10, + eval_mb_size=10, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = SynapticIntelligence( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, si_lambda=0.0001, train_epochs=1, train_mb_size=10, @@ -766,18 +794,19 @@ def test_cope(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = CoPE( - model, - optimizer, - criterion, - mem_size=10, - n_classes=n_classes, - p_size=emb_size, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = CoPE( + model, + optimizer, + criterion, + mem_size=10, + n_classes=n_classes, + p_size=emb_size, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario @@ -802,14 +831,15 @@ def test_pnn(self): # eval on future tasks is not allowed. model = PNN(num_layers=3, in_features=6, hidden_features_per_column=10) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - strategy = PNNStrategy( - model, - optimizer, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = PNNStrategy( + model, + optimizer, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) # train and test loop benchmark = self.load_benchmark(use_task_labels=True) @@ -820,17 +850,18 @@ def test_pnn(self): def test_expertgate(self): model = ExpertGate(shape=(3, 227, 227), device=self.device) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - strategy = ExpertGateStrategy( - model, - optimizer, - device=self.device, - train_mb_size=10, - train_epochs=2, - eval_mb_size=50, - ae_train_mb_size=10, - ae_train_epochs=2, - ae_lr=5e-4, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = ExpertGateStrategy( + model, + optimizer, + device=self.device, + train_mb_size=10, + train_epochs=2, + eval_mb_size=50, + ae_train_mb_size=10, + ae_train_epochs=2, + ae_lr=5e-4, + ) # Mandatory transform for AlexNet # 3 Channels and input size should be a minimum of 227 @@ -858,35 +889,37 @@ def test_icarl(self): cls = torch.nn.Sigmoid() optimizer = SGD([*fe.parameters(), *cls.parameters()], lr=0.001) - strategy = ICaRL( - fe, - cls, - optimizer, - 20, - buffer_transform=None, - fixed_memory=True, - train_mb_size=32, - train_epochs=2, - eval_mb_size=50, - device=self.device, - eval_every=-1, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = ICaRL( + fe, + cls, + optimizer, + 20, + buffer_transform=None, + fixed_memory=True, + train_mb_size=32, + train_epochs=2, + eval_mb_size=50, + device=self.device, + eval_every=-1, + ) run_strategy(benchmark, strategy) def test_lfl(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = LFL( - model, - optimizer, - criterion, - lambda_e=0.0001, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = LFL( + model, + optimizer, + criterion, + lambda_e=0.0001, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario @@ -907,17 +940,18 @@ def test_lfl(self): def test_mas(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = MAS( - model, - optimizer, - criterion, - lambda_reg=1.0, - alpha=0.5, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = MAS( + model, + optimizer, + criterion, + lambda_reg=1.0, + alpha=0.5, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario @@ -939,46 +973,48 @@ def test_mas(self): def test_bic(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = BiC( - model, - optimizer, - criterion, - mem_size=50, - val_percentage=0.1, - T=2, - stage_2_epochs=10, - lamb=-1, - lr=0.01, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = BiC( + model, + optimizer, + criterion, + mem_size=50, + val_percentage=0.1, + T=2, + stage_2_epochs=10, + lamb=-1, + lr=0.01, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) def test_mir(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = MIR( - model, - optimizer, - criterion, - mem_size=1000, - batch_size_mem=10, - subsample=50, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = MIR( + model, + optimizer, + criterion, + mem_size=1000, + batch_size_mem=10, + subsample=50, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) # MT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True) strategy = MIR( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, mem_size=1000, batch_size_mem=10, subsample=50, @@ -992,70 +1028,74 @@ def test_mir(self): def test_erace(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = ER_ACE( - model, - optimizer, - criterion, - mem_size=1000, - batch_size_mem=10, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = ER_ACE( + model, + optimizer, + criterion, + mem_size=1000, + batch_size_mem=10, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) def test_eraml(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = ER_AML( - model, - model.features, - optimizer, - criterion, - temp=0.1, - base_temp=0.07, - mem_size=1000, - batch_size_mem=10, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = ER_AML( + model, + model.features, + optimizer, + criterion, + temp=0.1, + base_temp=0.07, + mem_size=1000, + batch_size_mem=10, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) def test_l2p(self): _, _, _, benchmark = self.init_scenario(multi_task=False) - strategy = LearningToPrompt( - model_name="simpleMLP", - criterion=CrossEntropyLoss(), - train_mb_size=10, - device=self.device, - train_epochs=1, - num_classes=10, - eval_mb_size=50, - use_cls_features=False, - use_mask=False, - use_vit=False, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = LearningToPrompt( + "simpleMLP", + criterion=CrossEntropyLoss(), + train_mb_size=10, + device=self.device, + train_epochs=1, + num_classes=10, + eval_mb_size=50, + use_cls_features=False, + use_mask=False, + use_vit=False, + ) run_strategy(benchmark, strategy) def test_der(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - strategy = DER( - model, - optimizer, - criterion, - mem_size=1000, - batch_size_mem=10, - train_mb_size=10, - device=self.device, - eval_mb_size=50, - train_epochs=2, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = DER( + model, + optimizer, + criterion, + mem_size=1000, + batch_size_mem=10, + train_mb_size=10, + device=self.device, + eval_mb_size=50, + train_epochs=2, + ) run_strategy(benchmark, strategy) def test_feature_distillation(self): @@ -1073,9 +1113,9 @@ def test_feature_distillation(self): plugins = [feature_distillation] strategy = Naive( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, device=self.device, train_mb_size=10, eval_mb_size=50, @@ -1094,16 +1134,17 @@ def test_feature_replay(self): # We do not add MT criterion cause FeatureReplay uses # MaskedCrossEntropy as main criterion - strategy = FeatureReplay( - model, - optimizer, - device=self.device, - last_layer_name=last_fc_name, - train_mb_size=10, - eval_mb_size=50, - train_epochs=2, - plugins=plugins, - ) + with self.assertWarns(PositionalArgumentDeprecatedWarning): + strategy = FeatureReplay( + model, + optimizer, + device=self.device, + last_layer_name=last_fc_name, + train_mb_size=10, + eval_mb_size=50, + train_epochs=2, + plugins=plugins, + ) run_strategy(benchmark, strategy) def load_benchmark( diff --git a/tests/training/test_strategies_accuracy.py b/tests/training/test_strategies_accuracy.py index 63776c7d8..481f94367 100644 --- a/tests/training/test_strategies_accuracy.py +++ b/tests/training/test_strategies_accuracy.py @@ -82,9 +82,9 @@ def test_multihead_cumulative(self): exp_acc = ExperienceAccuracy() evalp = EvaluationPlugin(main_metric, exp_acc, loggers=None) strategy = Cumulative( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=64, device=get_device(), eval_mb_size=512, @@ -111,8 +111,8 @@ def test_pnn(self): model = PNN(num_layers=1, in_features=6, hidden_features_per_column=50) optimizer = SGD(model.parameters(), lr=0.1) strategy = PNNStrategy( - model, - optimizer, + model=model, + optimizer=optimizer, train_mb_size=32, device=get_device(), eval_mb_size=512, diff --git a/tests/training/test_supervised_regression.py b/tests/training/test_supervised_regression.py index 76cff8f9b..f1af6950c 100644 --- a/tests/training/test_supervised_regression.py +++ b/tests/training/test_supervised_regression.py @@ -34,6 +34,7 @@ from avalanche.training.plugins.clock import Clock from avalanche.training.plugins import EvaluationPlugin +from avalanche.training.supervised.strategy_wrappers import Naive from avalanche.training.templates import SupervisedTemplate from avalanche.training.utils import trigger_plugins from tests.unit_tests_utils import get_fast_benchmark @@ -396,9 +397,9 @@ def __init__( eval_every=-1, ): super().__init__( - model, - optimizer, - criterion, + model=model, + optimizer=optimizer, + criterion=criterion, train_mb_size=train_mb_size, train_epochs=train_epochs, eval_mb_size=eval_mb_size, @@ -541,9 +542,9 @@ def test_reproduce_old_base_strategy(self): # if you want to check whether the seeds are set correctly # switch SupervisedTemplate with OldBaseStrategy and check that # values are exactly equal. - new_strategy = SupervisedTemplate( - new_model, - SGD(new_model.parameters(), lr=0.01), + new_strategy = Naive( + model=new_model, + optimizer=SGD(new_model.parameters(), lr=0.01), train_epochs=2, plugins=[p_new], evaluator=None, From 043effa674585a9fa0fab706b8952e0ceddb160f Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Thu, 1 Feb 2024 21:38:32 +0100 Subject: [PATCH 05/21] WIP: better management of legacy positional arguments --- avalanche/training/supervised/ar1.py | 3 +- avalanche/training/supervised/cumulative.py | 9 +- avalanche/training/supervised/deep_slda.py | 22 +- avalanche/training/supervised/der.py | 7 +- avalanche/training/supervised/er_ace.py | 7 +- avalanche/training/supervised/er_aml.py | 13 +- avalanche/training/supervised/expert_gate.py | 7 +- .../training/supervised/feature_replay.py | 7 +- avalanche/training/supervised/icarl.py | 25 +- .../training/supervised/joint_training.py | 9 +- avalanche/training/supervised/l2p.py | 7 +- avalanche/training/supervised/lamaml.py | 6 +- avalanche/training/supervised/lamaml_v2.py | 6 +- avalanche/training/supervised/mer.py | 7 +- .../training/supervised/strategy_wrappers.py | 185 +++++++-------- .../supervised/strategy_wrappers_online.py | 7 +- avalanche/training/templates/base.py | 223 +++++++++++++++++- avalanche/training/templates/base_sgd.py | 2 +- .../training/templates/common_templates.py | 140 ++--------- .../observation_type/batch_observation.py | 13 +- .../problem_type/supervised_problem.py | 11 +- .../templates/strategy_mixin_protocol.py | 22 +- .../templates/update_type/meta_update.py | 9 +- .../templates/update_type/sgd_update.py | 13 +- examples/naive.py | 2 +- tests/models/test_packnet.py | 2 +- tests/training/test_online_strategies.py | 2 +- tests/training/test_strategies.py | 2 +- 28 files changed, 427 insertions(+), 341 deletions(-) diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index 999924e69..a06370706 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -43,7 +43,7 @@ class AR1(SupervisedTemplate): def __init__( self, - *args, + *, criterion: CriterionType = None, lr: float = 0.001, inc_lr: float = 5e-5, @@ -169,7 +169,6 @@ def __init__( self.replay_mb_size = 0 super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index c324bbef3..314966be2 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -20,10 +20,10 @@ class Cumulative(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -55,7 +55,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 2ff1ae4cd..7bc336d74 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -3,6 +3,7 @@ import os import torch +from torch.nn import Module from avalanche.training.plugins import SupervisedPlugin from avalanche.training.templates import SupervisedTemplate @@ -29,11 +30,11 @@ class StreamingLDA(SupervisedTemplate): def __init__( self, - *args, - slda_model="not_set", - criterion: CriterionType = "not_set", - input_size="not_set", - num_classes: int = "not_set", + *, + slda_model: Module, + criterion: CriterionType, + input_size: int, + num_classes: int, output_layer_name: Optional[str] = None, shrinkage_param=1e-4, streaming_update_sigma=True, @@ -74,24 +75,13 @@ def __init__( if plugins is None: plugins = [] - adapt_legacy_args = slda_model == "not_set" - slda_model = slda_model if slda_model != "not_set" else args[0] slda_model = slda_model.eval() if output_layer_name is not None: slda_model = FeatureExtractorBackbone( slda_model.to(device), output_layer_name ).eval() - # Legacy positional arguments support (deprecation cycle) - if adapt_legacy_args: - args = list(args) - args[0] = slda_model - - input_size = input_size if input_size != "not_set" else args[2] - num_classes = num_classes if num_classes != "not_set" else args[3] - super(StreamingLDA, self).__init__( - legacy_positional_args=args, model=slda_model, optimizer=None, # type: ignore criterion=criterion, diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index bbc4b2186..7df915748 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -157,9 +157,9 @@ class DER(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), mem_size: int = 200, batch_size_mem: Optional[int] = None, @@ -211,7 +211,6 @@ def __init__( `eval_every` experience or iterations (Default='experience'). """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index d68044330..d9a260e15 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -33,9 +33,9 @@ class ER_ACE(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), mem_size: int = 200, batch_size_mem: int = 10, @@ -74,7 +74,6 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/er_aml.py b/avalanche/training/supervised/er_aml.py index efff025c0..a79b574e0 100644 --- a/avalanche/training/supervised/er_aml.py +++ b/avalanche/training/supervised/er_aml.py @@ -28,10 +28,10 @@ class ER_AML(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - feature_extractor: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + feature_extractor: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), temp: float = 0.1, base_temp: float = 0.07, @@ -77,7 +77,6 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -98,9 +97,7 @@ def __init__( ) self.replay_loader = None self.aml_criterion = AMLCriterion( - feature_extractor=( - feature_extractor if feature_extractor != "not_set" else args[1] - ), + feature_extractor=feature_extractor, temp=temp, base_temp=base_temp, same_task_neg=same_task_neg, diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index 21c8b0fca..fd04b7ad7 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -42,9 +42,9 @@ class ExpertGateStrategy(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -109,7 +109,6 @@ def __init__( ) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/feature_replay.py b/avalanche/training/supervised/feature_replay.py index be7293daf..04e355f4e 100644 --- a/avalanche/training/supervised/feature_replay.py +++ b/avalanche/training/supervised/feature_replay.py @@ -47,9 +47,9 @@ class FeatureReplay(SupervisedTemplate): def __init__( self, - *args, - model: nn.Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: nn.Module, + optimizer: Optimizer, criterion: CriterionType = MaskedCrossEntropy(), last_layer_name: str = "classifier", mem_size: int = 200, @@ -67,7 +67,6 @@ def __init__( **kwargs ): super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/icarl.py b/avalanche/training/supervised/icarl.py index fb81d8c17..fa292ba73 100644 --- a/avalanche/training/supervised/icarl.py +++ b/avalanche/training/supervised/icarl.py @@ -28,13 +28,13 @@ class ICaRL(SupervisedTemplate): def __init__( self, - *args, - feature_extractor: Module = "not_set", - classifier: Module = "not_set", - optimizer: Optimizer = "not_set", - memory_size="not_set", - buffer_transform="not_set", - fixed_memory="not_set", + *, + feature_extractor: Module, + classifier: Module, + optimizer: Optimizer, + memory_size: int, + buffer_transform, + fixed_memory: bool, train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = None, @@ -73,16 +73,16 @@ def __init__( learning experience. """ model = TrainEvalModel( - feature_extractor if feature_extractor != "not_set" else args[0], - train_classifier=classifier if classifier != "not_set" else args[1], + feature_extractor, + train_classifier=classifier, eval_classifier=NCMClassifier(normalize=True), ) criterion = ICaRLLossPlugin() # iCaRL requires this specific loss (#966) icarl = _ICaRLPlugin( - memory_size if memory_size != "not_set" else args[3], - buffer_transform if buffer_transform != "not_set" else args[4], - fixed_memory if fixed_memory != "not_set" else args[5], + memory_size, + buffer_transform, + fixed_memory, ) if plugins is None: @@ -94,7 +94,6 @@ def __init__( plugins += [criterion] super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index a52374a30..9b6ca3dc4 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -58,10 +58,10 @@ class JointTraining(SupervisedTemplate[TDatasetExperience, TMBInput, TMBOutput]) def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: int = 1, @@ -91,7 +91,6 @@ def __init__( `eval` is called every `eval_every` epochs and at the end of the learning experience.""" super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/l2p.py b/avalanche/training/supervised/l2p.py index 410df3b57..2d519991a 100644 --- a/avalanche/training/supervised/l2p.py +++ b/avalanche/training/supervised/l2p.py @@ -31,8 +31,8 @@ class LearningToPrompt(SupervisedTemplate): def __init__( self, - *args, - model_name: str = "not_set", + *, + model_name: str, criterion: nn.Module = nn.CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -99,7 +99,7 @@ def __init__( self.lr = lr self.sim_coefficient = sim_coefficient model = create_model( - model_name=model_name if model_name != "not_set" else args[0], + model_name=model_name, prompt_pool=prompt_pool, pool_size=pool_size, prompt_length=prompt_length, @@ -131,7 +131,6 @@ def __init__( ) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index 35eb0cc46..d727a778f 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -27,9 +27,9 @@ class LaMAML(SupervisedMetaLearningTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), n_inner_updates: int = 5, second_order: bool = True, diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index fc0492aee..42fd8bfbf 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -24,9 +24,9 @@ class LaMAML(SupervisedMetaLearningTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), n_inner_updates: int = 5, second_order: bool = True, diff --git a/avalanche/training/supervised/mer.py b/avalanche/training/supervised/mer.py index b546eb5d2..a834a5930 100644 --- a/avalanche/training/supervised/mer.py +++ b/avalanche/training/supervised/mer.py @@ -52,9 +52,9 @@ def get_batch(self, x, y, t): class MER(SupervisedMetaLearningTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), mem_size=200, batch_size_mem=10, @@ -88,7 +88,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index f8d6ffce7..c8ddcab44 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -66,9 +66,9 @@ class Naive(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -104,7 +104,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -127,9 +126,9 @@ class PNNStrategy(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -164,7 +163,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -203,11 +201,11 @@ class PackNet(SupervisedTemplate): def __init__( self, - *args, - model: Union[PackNetModule, PackNetModel] = "not_set", - optimizer: Optimizer = "not_set", - post_prune_epochs: int = "not_set", - prune_proportion: float = "not_set", + *, + model: Union[PackNetModule, PackNetModel], + optimizer: Optimizer, + post_prune_epochs: int, + prune_proportion: float, criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -250,7 +248,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -273,10 +270,10 @@ class CWRStar(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, cwr_layer_name: str, train_mb_size: int = 1, train_epochs: int = 1, @@ -313,7 +310,7 @@ def __init__( """ cwsp = CWRStarPlugin( - model if model != "not_set" else args[0], + model, cwr_layer_name, freeze_remaining_model=True, ) @@ -322,7 +319,6 @@ def __init__( else: plugins.append(cwsp) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -346,10 +342,10 @@ class Replay(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, mem_size: int = 200, train_mb_size: int = 1, train_epochs: int = 1, @@ -390,7 +386,6 @@ def __init__( else: plugins.append(rp) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -426,9 +421,9 @@ class GenerativeReplay(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -523,7 +518,6 @@ def __init__( plugins.append(rp) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -557,7 +551,7 @@ class AETraining(SupervisedTemplate): def __init__( self, - *args, + *, model: Module, optimizer: Optimizer, device, @@ -593,7 +587,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -631,9 +624,9 @@ class VAETraining(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion=VAE_loss, train_mb_size: int = 1, train_epochs: int = 1, @@ -669,7 +662,6 @@ def __init__( """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -698,10 +690,10 @@ class GSS_greedy(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, mem_size: int = 200, mem_strength=1, input_size=[], @@ -769,10 +761,10 @@ class GDumb(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, mem_size: int = 200, train_mb_size: int = 1, train_epochs: int = 1, @@ -814,7 +806,6 @@ def __init__( plugins.append(gdumb) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -837,10 +828,10 @@ class LwF(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, alpha: Union[float, Sequence[float]], temperature: float, train_mb_size: int = 1, @@ -885,7 +876,6 @@ def __init__( plugins.append(lwf) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -909,10 +899,10 @@ class AGEM(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, patterns_per_exp: int, sample_size: int = 64, train_mb_size: int = 1, @@ -957,7 +947,6 @@ def __init__( plugins.append(agem) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -981,10 +970,10 @@ class GEM(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, patterns_per_exp: int, memory_strength: float = 0.5, train_mb_size: int = 1, @@ -1029,7 +1018,6 @@ def __init__( plugins.append(gem) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1053,10 +1041,10 @@ class EWC(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, ewc_lambda: float, mode: str = "separate", decay_factor: Optional[float] = None, @@ -1112,7 +1100,6 @@ def __init__( plugins.append(ewc) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1145,10 +1132,10 @@ class SynapticIntelligence(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, si_lambda: Union[float, Sequence[float]], eps: float = 0.0000001, train_mb_size: int = 1, @@ -1199,7 +1186,6 @@ def __init__( plugins.append(SynapticIntelligencePlugin(si_lambda=si_lambda, eps=eps)) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1223,10 +1209,10 @@ class CoPE(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, mem_size: int = 200, n_classes: int = 10, p_size: int = 100, @@ -1279,7 +1265,6 @@ def __init__( else: plugins.append(copep) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1304,10 +1289,10 @@ class LFL(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, lambda_e: Union[float, Sequence[float]], train_mb_size: int = 1, train_epochs: int = 1, @@ -1349,7 +1334,6 @@ def __init__( plugins.append(lfl) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1373,10 +1357,10 @@ class MAS(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, lambda_reg: float = 1.0, alpha: float = 0.5, verbose: bool = False, @@ -1428,7 +1412,6 @@ def __init__( plugins.append(mas) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1452,10 +1435,10 @@ class BiC(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, mem_size: int = 200, val_percentage: float = 0.1, T: int = 2, @@ -1522,7 +1505,6 @@ def __init__( plugins.append(bic) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1544,12 +1526,12 @@ class MIR(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", - mem_size: int = "not_set", - subsample: int = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, + mem_size: int, + subsample: int, batch_size_mem: int = 1, train_mb_size: int = 1, train_epochs: int = 1, @@ -1589,8 +1571,8 @@ def __init__( # Instantiate plugin mir = MIRPlugin( - mem_size=mem_size if mem_size != "not_set" else args[3], - subsample=subsample if subsample != "not_set" else args[5], + mem_size=mem_size, + subsample=subsample, batch_size_mem=batch_size_mem, ) @@ -1601,7 +1583,6 @@ def __init__( plugins.append(mir) super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, @@ -1627,10 +1608,10 @@ class FromScratchTraining(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", - criterion: CriterionType = "not_set", + *, + model: Module, + optimizer: Optimizer, + criterion: CriterionType, reset_optimizer: bool = True, train_mb_size: int = 1, train_epochs: int = 1, diff --git a/avalanche/training/supervised/strategy_wrappers_online.py b/avalanche/training/supervised/strategy_wrappers_online.py index b4fb71d39..0210eb005 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -44,9 +44,9 @@ class OnlineNaive(SupervisedTemplate): def __init__( self, - *args, - model: Module = "not_set", - optimizer: Optimizer = "not_set", + *, + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), train_passes: int = 1, train_mb_size: int = 1, @@ -80,7 +80,6 @@ def __init__( learning experience. """ super().__init__( - legacy_positional_args=args, model=model, optimizer=optimizer, criterion=criterion, diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 26379272e..c1d85957a 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -1,7 +1,17 @@ import sys import warnings from collections import defaultdict -from typing import Iterable, Sequence, Optional, TypeVar, Union, List +from typing import ( + Any, + Dict, + Iterable, + OrderedDict, + Sequence, + Optional, + TypeVar, + Union, + List, +) import torch from torch.nn import Module @@ -11,12 +21,211 @@ from avalanche.distributed.distributed_helper import DistributedHelper from avalanche.training.templates.strategy_mixin_protocol import BaseStrategyProtocol from avalanche.training.utils import trigger_plugins - +import functools +import inspect TExperienceType = TypeVar("TExperienceType", bound=CLExperience) TPluginType = TypeVar("TPluginType", bound=BasePlugin, contravariant=True) +def _warn_init_has_positional_args(init_method, class_name): + init_args = inspect.signature(init_method).parameters + positional_args = [ + k + for k, v in init_args.items() + if v.kind + in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + } + ] + if len(positional_args) > 1: # self is always present + warnings.warn( + f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " + f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " + f"Your strategy {class_name}.__init__ method still has some positional-only or " + f"positional-or-keyword arguments. Consider removing them. Offending arguments: {positional_args}. " + f"This can be achieved by adding a * in the argument list of your __init__ method just after 'self'. " + f"More info: https://peps.python.org/pep-3102/#specification" + ) + + +def _merge_legacy_positional_arguments( + init_method, + class_name: str, + args: Sequence[Any], + kwargs: Dict[str, Any], + strict_init_check=True, + allow_pos_args=True, +): + """ + Manage the legacy positional constructor parameters. + + Used to warn the user when passing positional parameters to strategy constructors + (which is deprecated). + + To allow for a smooth transition, we allow the user to pass positional + arguments to the constructor. However, we warn the user that + this soft transition mechanism will be removed in the future. + """ + + if len(args) == 1: + # No positional argument has been passed (good!) + return args, kwargs + elif len(args) == 0: + # This should never happen and will fail later. + # assert len(args) == 0, "At least the 'self' argument should be passed" + return args, kwargs + + all_init_args = dict(inspect.signature(init_method).parameters) + + # Remove 'self' from the list of arguments + all_init_args.pop("self") + + # Divide parameters in groups + pos_only_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.POSITIONAL_ONLY + ] + pos_or_keyword = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + var_pos = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.VAR_POSITIONAL + ] + keyword_only_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.KEYWORD_ONLY + ] + var_keyword_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.VAR_KEYWORD + ] + + error_str = ( + f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " + f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " + f"Your are passing {len(args) - 1} positional arguments to the {class_name}.__init__ method. " + f"Consider passing them as names arguments." + ) + + if allow_pos_args: + error_str += ( + " The ability to pass positional arguments will be removed in the future." + ) + warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + else: + raise PositionalArgumentDeprecatedWarning(error_str) + + args_to_manage = list(args) + kwargs_to_manage = dict(kwargs) + + result_args = [args_to_manage.pop(0)] # Add self + result_kwargs = OrderedDict() + + unset_arguments = set(all_init_args.keys()) + + for argument_name, arg_def in pos_only_args: + if len(args_to_manage) > 0: + result_args.append(args_to_manage.pop(0)) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Positional-only argument {argument_name} is not set (and no default set)." + ) + unset_arguments.remove(argument_name) + + for argument_name, arg_def in pos_or_keyword: + if len(args_to_manage) > 0: + result_args.append(args_to_manage.pop(0)) + elif argument_name in kwargs_to_manage: + result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Parameter {argument_name} is not set (and no default provided)." + ) + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + unset_arguments.remove(argument_name) + + if len(var_pos) > 0: + # assert len(var_pos) == 1, "Only one var-positional argument is supported" + argument_name = var_pos[0][0] + if len(args_to_manage) > 0: + result_args.extend(args_to_manage) + args_to_manage = list() + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + + unset_arguments.remove(argument_name) + + for argument_name, arg_def in keyword_only_args: + if len(args_to_manage) > 0 and argument_name in kwargs_to_manage: + raise TypeError( + f"Got multiple values for argument '{argument_name}' (passed as both positional and named parameter)" + ) + + if len(args_to_manage) > 0: + # This is where the soft transition mechanism is implemented. + # The legacy positional arguments are transformed to keyword arguments. + result_kwargs[argument_name] = args_to_manage.pop(0) + elif argument_name in kwargs_to_manage: + result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Keyword-only parameter {argument_name} is not set (and no default set)." + ) + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + + unset_arguments.remove(argument_name) + + if len(var_keyword_args) > 0: + # assert len(var_keyword_args) == 1, "Only one var-keyword argument is supported" + argument_name = var_keyword_args[0][0] + result_kwargs.update(kwargs_to_manage) + kwargs_to_manage = dict() + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + unset_arguments.remove(argument_name) + + assert len(unset_arguments) == 0 + + return result_args, result_kwargs + + +def _support_legacy_strategy_positional_args(func, cls_name): + @functools.wraps(func) + def wrap_init(*args, **kwargs): + # print("wrap_init", cls_name, args[0]) + _warn_init_has_positional_args(func, cls_name) + args, kwargs = _merge_legacy_positional_arguments( + func, cls_name, args, kwargs, strict_init_check=False, allow_pos_args=True + ) + return func(*args, **kwargs) + + return wrap_init + + class BaseTemplate(BaseStrategyProtocol[TExperienceType]): """Base class for continual learning skeletons. @@ -34,6 +243,12 @@ class BaseTemplate(BaseStrategyProtocol[TExperienceType]): """ + def __init_subclass__(cls, **kwargs): + # This is needed to manage the transition to keyword-only arguments. + cls.__init__ = _support_legacy_strategy_positional_args( + cls.__init__, cls.__name__ + ) + # we need this only for type checking PLUGIN_CLASS = BasePlugin @@ -353,3 +568,7 @@ def _experiences_parameter_as_iterable( __all__ = ["BaseTemplate"] + + +class PositionalArgumentDeprecatedWarning(UserWarning): + pass diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 600af3361..e0133179c 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -9,7 +9,7 @@ from torch import Tensor from avalanche.benchmarks import CLExperience, CLStream -from avalanche.benchmarks.scenarios.deprecated import DatasetExperience +from avalanche.benchmarks import DatasetExperience from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.core import BasePlugin, BaseSGDPlugin from avalanche.training.plugins import EvaluationPlugin diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index d15f00bcc..7dbde7e9a 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,21 +1,23 @@ import sys -from typing import Any, Callable, Dict, List, Sequence, Optional, TypeVar, Union +from typing import Any, Callable, Dict, Sequence, Optional, TypeVar, Union import warnings import torch import inspect from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer -from torch import Tensor -from ...benchmarks.scenarios.deprecated import DatasetExperience +from avalanche.benchmarks import DatasetExperience -from avalanche.core import BasePlugin, BaseSGDPlugin, SupervisedPlugin +from avalanche.core import BasePlugin, SupervisedPlugin from avalanche.training.plugins.evaluation import ( EvaluationPlugin, default_evaluator, ) +from avalanche.training.templates.base import PositionalArgumentDeprecatedWarning from avalanche.training.templates.strategy_mixin_protocol import ( SupervisedStrategyProtocol, + TMBOutput, + TMBInput, ) from .observation_type import * @@ -25,112 +27,12 @@ TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience) -TMBInput = TypeVar("TMBInput") -TMBOutput = TypeVar("TMBOutput") - - -class PositionalArgumentDeprecatedWarning(UserWarning): - pass - - -def _merge_legacy_positional_arguments( - self, - args: Optional[Sequence[Any]], - kwargs: Dict[str, Any], - strict_init_check=True, - allow_pos_args=True, -): - """ - Manage the legacy positional constructor parameters. - - Used to warn the user about the deprecation of positional parameters - in strategy constructors. - - To allow for a smooth transition, we allow the user to pass positional - arguments to the constructor. However, we want to warn the user that - this is deprecated and will be removed in the future. - """ - - init_method = getattr(self, "__init__") - init_kwargs = inspect.signature(init_method).parameters - - if self.__class__ in [SupervisedTemplate, SupervisedMetaLearningTemplate]: - return - - has_transition_varargs = any( - x.kind == inspect.Parameter.VAR_POSITIONAL for x in init_kwargs.values() - ) - if (not has_transition_varargs and args is not None) or ( - has_transition_varargs and args is None - ): - error_str = ( - "While trainsitioning to the new keyword-only strategy constructors, " - "the ability to pass positional arguments (as in older versions of Avalanche) should be granted. " - "You should catch all positional arguments (excluding self) using *args and do not declare other positional arguments! " - "Then, pass them to SupervisedTemplate/SupervisedMetaLearningTemplate super constructor as the legacy_args argument." - "Those legacy positional arguments will then be converted to keyword arguments " - "according to the declaration order of those keyword arguments." - ) - - if strict_init_check: - raise PositionalArgumentDeprecatedWarning(error_str) - else: - warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) - - positional_args = { - (k, x) - for k, x in init_kwargs.items() - if x.kind - in [inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] - } - - if len(positional_args) > 0: - all_positional_args_names_as_str = ", ".join([x[0] for x in positional_args]) - error_str = ( - "Avalanche is transitioning to strategy constructors that allow for keyword arguments only. " - "Your strategy __init__ method still has some positional-only or positional-or-keyword arguments. " - "Consider removing them. Offending arguments: " - + all_positional_args_names_as_str - ) - - if strict_init_check: - raise PositionalArgumentDeprecatedWarning(error_str) - else: - warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) - - # Discard all non-keyword params (also exclude VAR_KEYWORD) - kwargs_order = [ - k for k, v in init_kwargs.items() if v.kind == inspect.Parameter.KEYWORD_ONLY - ] - - print(init_kwargs) - - if len(args) > 0: - - error_str = ( - "Passing positional arguments to Strategy constructors is " - "deprecated. Please use keyword arguments instead." - ) - - if allow_pos_args: - warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) - else: - raise PositionalArgumentDeprecatedWarning(error_str) - - for i, arg in enumerate(args): - kwargs[kwargs_order[i]] = arg - - for key, value in kwargs.items(): - if value == "not_set": - raise ValueError(f"Parameter {key} is not set") - - return kwargs class SupervisedTemplate( - BatchObservation, - SupervisedProblem, - SGDUpdate, + BatchObservation[TDatasetExperience, TMBInput, TMBInput], + SupervisedProblem[TDatasetExperience, TMBInput, TMBInput], + SGDUpdate[TDatasetExperience, TMBInput, TMBOutput], SupervisedStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], BaseSGDTemplate[TDatasetExperience, TMBInput, TMBOutput], ): @@ -184,9 +86,8 @@ class SupervisedTemplate( def __init__( self, *, - legacy_positional_args: Optional[Sequence[Any]], - model: Module = "not_set", - optimizer: Optimizer = "not_set", + model: Module, + optimizer: Optimizer, criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -236,10 +137,6 @@ def __init__( kwargs["eval_every"] = eval_every kwargs["peval_mode"] = peval_mode - kwargs = _merge_legacy_positional_arguments( - self, legacy_positional_args, kwargs - ) - if sys.version_info >= (3, 11): super().__init__(**kwargs) else: @@ -248,9 +145,9 @@ def __init__( class SupervisedMetaLearningTemplate( - BatchObservation, - SupervisedProblem, - MetaUpdate, + BatchObservation[TDatasetExperience, TMBInput, TMBInput], + SupervisedProblem[TDatasetExperience, TMBInput, TMBInput], + MetaUpdate[TDatasetExperience, TMBInput, TMBInput], BaseSGDTemplate[TDatasetExperience, TMBInput, TMBOutput], ): """Base class for continual learning strategies. @@ -303,9 +200,8 @@ class SupervisedMetaLearningTemplate( def __init__( self, *, - legacy_positional_args: Optional[Sequence[Any]], - model: Module = "not_set", - optimizer: Optimizer = "not_set", + model: Module, + optimizer: Optimizer, criterion=CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, @@ -355,10 +251,6 @@ def __init__( kwargs["eval_every"] = eval_every kwargs["peval_mode"] = peval_mode - kwargs = _merge_legacy_positional_arguments( - self, legacy_positional_args, kwargs - ) - if sys.version_info >= (3, 11): super().__init__(**kwargs) else: diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index 4a8dde613..374d71c45 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -1,13 +1,18 @@ from avalanche.benchmarks import OnlineCLExperience from avalanche.models.utils import avalanche_model_adaptation -from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol +from avalanche.training.templates.strategy_mixin_protocol import ( + SGDStrategyProtocol, + TSGDExperienceType, + TMBInput, + TMBOutput, +) from avalanche.models.dynamic_optimizers import reset_optimizer, update_optimizer from avalanche.training.utils import _at_task_boundary -class BatchObservation(SGDStrategyProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class BatchObservation(SGDStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput]): + def __init__(self, **kwargs): + super().__init__(**kwargs) self.optimized_param_id = None def model_adaptation(self, model=None): diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index 2479dfe05..1de5639ab 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -1,15 +1,20 @@ from avalanche.models import avalanche_forward from avalanche.training.templates.strategy_mixin_protocol import ( SupervisedStrategyProtocol, + TSGDExperienceType, + TMBInput, + TMBOutput, ) # Types are perfectly ok for MyPy # Also confirmed here: https://stackoverflow.com/a/70907644 # PyLance just does not understand it -class SupervisedProblem(SupervisedStrategyProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class SupervisedProblem( + SupervisedStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput] +): + def __init__(self, **kwargs): + super().__init__(**kwargs) @property def mb_x(self): diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index 98f664f7e..c2635f4ad 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -14,8 +14,8 @@ TExperienceType = TypeVar("TExperienceType", bound=CLExperience) TSGDExperienceType = TypeVar("TSGDExperienceType", bound=DatasetExperience) -TMBinput = TypeVar("TMBinput") -TMBoutput = TypeVar("TMBoutput") +TMBInput = TypeVar("TMBInput") +TMBOutput = TypeVar("TMBOutput") class BaseStrategyProtocol(Protocol[TExperienceType]): @@ -34,17 +34,17 @@ class BaseStrategyProtocol(Protocol[TExperienceType]): class SGDStrategyProtocol( BaseStrategyProtocol[TSGDExperienceType], - Protocol[TSGDExperienceType, TMBinput, TMBoutput], + Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): """ A protocol for strategies to be used for typing mixin classes. """ - mbatch: Optional[TMBinput] + mbatch: Optional[TMBInput] - mb_output: Optional[TMBoutput] + mb_output: Optional[TMBOutput] - dataloader: Iterable[TMBinput] + dataloader: Iterable[TMBInput] _stop_training: bool @@ -54,7 +54,7 @@ class SGDStrategyProtocol( _criterion: Module - def forward(self) -> TMBoutput: ... + def forward(self) -> TMBOutput: ... def criterion(self) -> Tensor: ... @@ -88,8 +88,8 @@ def _after_training_iteration(self, **kwargs): ... class SupervisedStrategyProtocol( - SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], - Protocol[TSGDExperienceType, TMBinput, TMBoutput], + SGDStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput], + Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): mb_x: Tensor @@ -99,8 +99,8 @@ class SupervisedStrategyProtocol( class MetaLearningStrategyProtocol( - SGDStrategyProtocol[TSGDExperienceType, TMBinput, TMBoutput], - Protocol[TSGDExperienceType, TMBinput, TMBoutput], + SGDStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput], + Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): def _before_inner_updates(self, **kwargs): ... diff --git a/avalanche/training/templates/update_type/meta_update.py b/avalanche/training/templates/update_type/meta_update.py index 2217d2d07..b74c0066e 100644 --- a/avalanche/training/templates/update_type/meta_update.py +++ b/avalanche/training/templates/update_type/meta_update.py @@ -1,12 +1,15 @@ from avalanche.training.templates.strategy_mixin_protocol import ( MetaLearningStrategyProtocol, + TSGDExperienceType, + TMBInput, + TMBOutput, ) from avalanche.training.utils import trigger_plugins -class MetaUpdate(MetaLearningStrategyProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class MetaUpdate(MetaLearningStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput]): + def __init__(self, **kwargs): + super().__init__(**kwargs) def training_epoch(self, **kwargs): """Training epoch. diff --git a/avalanche/training/templates/update_type/sgd_update.py b/avalanche/training/templates/update_type/sgd_update.py index f44343367..934c9cfe6 100644 --- a/avalanche/training/templates/update_type/sgd_update.py +++ b/avalanche/training/templates/update_type/sgd_update.py @@ -1,9 +1,14 @@ -from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol +from avalanche.training.templates.strategy_mixin_protocol import ( + SGDStrategyProtocol, + TSGDExperienceType, + TMBInput, + TMBOutput, +) -class SGDUpdate(SGDStrategyProtocol): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class SGDUpdate(SGDStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput]): + def __init__(self, **kwargs): + super().__init__(**kwargs) def training_epoch(self, **kwargs): """Training epoch. diff --git a/examples/naive.py b/examples/naive.py index 8df7a9a2e..f03c1fc6f 100644 --- a/examples/naive.py +++ b/examples/naive.py @@ -42,7 +42,7 @@ def main(): model, optimizer, criterion, - 1, + train_epochs=1, device=device, train_mb_size=32, evaluator=eval_plugin, diff --git a/tests/models/test_packnet.py b/tests/models/test_packnet.py index f87992cbd..5b423eca9 100644 --- a/tests/models/test_packnet.py +++ b/tests/models/test_packnet.py @@ -1,7 +1,7 @@ import unittest from avalanche.models.packnet import PackNetModel, packnet_simple_mlp from avalanche.training.supervised.strategy_wrappers import PackNet -from avalanche.training.templates.common_templates import ( +from avalanche.training.templates.base import ( PositionalArgumentDeprecatedWarning, ) from torch.optim import SGD, Adam diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py index aec164887..852fd0db6 100644 --- a/tests/training/test_online_strategies.py +++ b/tests/training/test_online_strategies.py @@ -9,7 +9,7 @@ from avalanche.models import SimpleMLP from avalanche.benchmarks.scenarios.online import OnlineCLScenario from avalanche.training import OnlineNaive -from avalanche.training.templates.common_templates import ( +from avalanche.training.templates.base import ( PositionalArgumentDeprecatedWarning, ) from tests.unit_tests_utils import get_fast_benchmark diff --git a/tests/training/test_strategies.py b/tests/training/test_strategies.py index 095c7224a..30bbd6921 100644 --- a/tests/training/test_strategies.py +++ b/tests/training/test_strategies.py @@ -60,7 +60,7 @@ from avalanche.training.supervised.strategy_wrappers import PNNStrategy from avalanche.training.templates import SupervisedTemplate from avalanche.training.templates.base import _group_experiences_by_stream -from avalanche.training.templates.common_templates import ( +from avalanche.training.templates.base import ( PositionalArgumentDeprecatedWarning, ) from avalanche.training.utils import get_last_fc_layer From 384d82a8c33702c5130d800cbaab189ee75dcaf1 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 12:13:14 +0100 Subject: [PATCH 06/21] Fix minor issue. Move declaration of CriterionType. --- .github/workflows/unit-test.yml | 2 - avalanche/training/supervised/ar1.py | 7 +--- avalanche/training/supervised/cumulative.py | 2 +- avalanche/training/supervised/deep_slda.py | 2 +- avalanche/training/supervised/der.py | 2 +- avalanche/training/supervised/er_ace.py | 2 +- avalanche/training/supervised/er_aml.py | 2 +- avalanche/training/supervised/expert_gate.py | 2 +- .../training/supervised/feature_replay.py | 2 +- .../training/supervised/joint_training.py | 2 +- avalanche/training/supervised/lamaml.py | 2 + avalanche/training/supervised/lamaml_v2.py | 2 + avalanche/training/supervised/mer.py | 2 +- .../training/supervised/strategy_wrappers.py | 4 +- .../supervised/strategy_wrappers_online.py | 2 +- avalanche/training/templates/base.py | 1 + avalanche/training/templates/base_sgd.py | 39 +++++++++---------- .../training/templates/common_templates.py | 10 ++--- .../templates/strategy_mixin_protocol.py | 10 +++-- tests/run_dist_tests.py | 4 ++ 20 files changed, 53 insertions(+), 48 deletions(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 769968d09..2972480ee 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,8 +59,6 @@ jobs: echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && echo "Running distributed training tests..." && - cd tests && PYTHONPATH=.. python run_dist_tests.py && - cd .. && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index a06370706..41040b9b0 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -17,7 +17,6 @@ CWRStarPlugin, ) from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import ( replace_bn_with_brn, get_last_fc_layer, @@ -27,6 +26,7 @@ LayerAndParameter, ) from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class AR1(SupervisedTemplate): @@ -44,7 +44,7 @@ class AR1(SupervisedTemplate): def __init__( self, *, - criterion: CriterionType = None, + criterion: CriterionType = CrossEntropyLoss(), lr: float = 0.001, inc_lr: float = 5e-5, momentum=0.9, @@ -149,9 +149,6 @@ def __init__( optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2) - if criterion is None: - criterion = CrossEntropyLoss() - self.ewc_lambda = ewc_lambda self.freeze_below_layer = freeze_below_layer self.rm_sz = rm_sz diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 314966be2..0e073456c 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -8,7 +8,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class Cumulative(SupervisedTemplate): diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 7bc336d74..c234c44c9 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -13,7 +13,7 @@ ) from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models import FeatureExtractorBackbone -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class StreamingLDA(SupervisedTemplate): diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index 7df915748..d9263f2a2 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -19,7 +19,7 @@ from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute from avalanche.benchmarks.utils.flat_data import FlatData -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle from avalanche.core import SupervisedPlugin from avalanche.training.plugins.evaluation import ( diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index d9a260e15..01146c432 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -13,7 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle diff --git a/avalanche/training/supervised/er_aml.py b/avalanche/training/supervised/er_aml.py index a79b574e0..9d44f1f1f 100644 --- a/avalanche/training/supervised/er_aml.py +++ b/avalanche/training/supervised/er_aml.py @@ -13,7 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index fd04b7ad7..abb249216 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -22,7 +22,7 @@ from avalanche.training.templates import SupervisedTemplate from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class ExpertGateStrategy(SupervisedTemplate): diff --git a/avalanche/training/supervised/feature_replay.py b/avalanche/training/supervised/feature_replay.py index 04e355f4e..e723b1515 100644 --- a/avalanche/training/supervised/feature_replay.py +++ b/avalanche/training/supervised/feature_replay.py @@ -11,7 +11,7 @@ from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle from avalanche.training.losses import MaskedCrossEntropy diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 9b6ca3dc4..08294cfc0 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -28,7 +28,7 @@ _experiences_parameter_as_iterable, _group_experiences_by_stream, ) -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class AlreadyTrainedError(Exception): diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index d727a778f..d880df347 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -8,6 +8,8 @@ from torch.optim import Optimizer import math +from avalanche.training.templates.strategy_mixin_protocol import CriterionType + try: import higher except ImportError: diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 42fd8bfbf..36905cb80 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -3,6 +3,8 @@ import warnings import torch +from avalanche.training.templates.strategy_mixin_protocol import CriterionType + if parse(torch.__version__) < parse("2.0.0"): warnings.warn(f"LaMAML requires torch >= 2.0.0.") diff --git a/avalanche/training/supervised/mer.py b/avalanche/training/supervised/mer.py index a834a5930..d4d709d17 100644 --- a/avalanche/training/supervised/mer.py +++ b/avalanche/training/supervised/mer.py @@ -11,7 +11,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.storage_policy import ReservoirSamplingBuffer from avalanche.training.templates import SupervisedMetaLearningTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class MERBuffer: diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index c8ddcab44..31fbe8f4d 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -45,11 +45,11 @@ ) from avalanche.training.templates.base import BaseTemplate from avalanche.training.templates import SupervisedTemplate -from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.evaluation.metrics import loss_metrics from avalanche.models.generator import MlpVAE, VAE_loss from avalanche.models.expert_gate import AE_loss from avalanche.logging import InteractiveLogger -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class Naive(SupervisedTemplate): diff --git a/avalanche/training/supervised/strategy_wrappers_online.py b/avalanche/training/supervised/strategy_wrappers_online.py index 0210eb005..ffe6f1575 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -21,7 +21,7 @@ SupervisedTemplate, ) from avalanche._annotations import deprecated -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType @deprecated( diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index c1d85957a..f491a951a 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -248,6 +248,7 @@ def __init_subclass__(cls, **kwargs): cls.__init__ = _support_legacy_strategy_positional_args( cls.__init__, cls.__name__ ) + super().__init_subclass__(**kwargs) # we need this only for type checking PLUGIN_CLASS = BasePlugin diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index e0133179c..e54f402a6 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union +from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union from typing_extensions import TypeAlias from packaging.version import parse @@ -22,7 +22,10 @@ collate_from_data_or_kwargs, ) -from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol +from avalanche.training.templates.strategy_mixin_protocol import ( + CriterionType, + SGDStrategyProtocol, +) from avalanche.training.utils import trigger_plugins @@ -30,10 +33,9 @@ TMBInput = TypeVar("TMBInput") TMBOutput = TypeVar("TMBOutput") -CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]] - class BaseSGDTemplate( + Generic[TDatasetExperience, TMBInput, TMBOutput], SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], BaseTemplate[TDatasetExperience], ): @@ -93,21 +95,6 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ - super().__init__( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=train_mb_size, - train_epochs=train_epochs, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - **kwargs - ) - # Call super with all args if sys.version_info >= (3, 11): super().__init__( @@ -127,7 +114,19 @@ def __init__( else: super().__init__() # type: ignore BaseTemplate.__init__( - self=self, model=model, device=device, plugins=plugins + self, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.optimizer: Optimizer = optimizer diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 7dbde7e9a..8405c5a6f 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,8 +1,6 @@ import sys -from typing import Any, Callable, Dict, Sequence, Optional, TypeVar, Union -import warnings +from typing import Callable, Sequence, Optional, TypeVar, Union import torch -import inspect from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer @@ -13,8 +11,8 @@ EvaluationPlugin, default_evaluator, ) -from avalanche.training.templates.base import PositionalArgumentDeprecatedWarning from avalanche.training.templates.strategy_mixin_protocol import ( + CriterionType, SupervisedStrategyProtocol, TMBOutput, TMBInput, @@ -23,7 +21,7 @@ from .observation_type import * from .problem_type import * from .update_type import * -from .base_sgd import BaseSGDTemplate, CriterionType +from .base_sgd import BaseSGDTemplate TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience) @@ -202,7 +200,7 @@ def __init__( *, model: Module, optimizer: Optimizer, - criterion=CrossEntropyLoss(), + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index c2635f4ad..4596c9d39 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Optional, TypeVar, Protocol +from typing import Generic, Iterable, List, Optional, TypeVar, Protocol, Callable, Union +from typing_extensions import TypeAlias from torch import Tensor import torch @@ -17,8 +18,10 @@ TMBInput = TypeVar("TMBInput") TMBOutput = TypeVar("TMBOutput") +CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]] -class BaseStrategyProtocol(Protocol[TExperienceType]): + +class BaseStrategyProtocol(Generic[TExperienceType], Protocol[TExperienceType]): model: Module device: torch.device @@ -33,6 +36,7 @@ class BaseStrategyProtocol(Protocol[TExperienceType]): class SGDStrategyProtocol( + Generic[TSGDExperienceType, TMBInput, TMBOutput], BaseStrategyProtocol[TSGDExperienceType], Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): @@ -52,7 +56,7 @@ class SGDStrategyProtocol( loss: Tensor - _criterion: Module + _criterion: CriterionType def forward(self) -> TMBOutput: ... diff --git a/tests/run_dist_tests.py b/tests/run_dist_tests.py index fa8d9f94d..232e98645 100644 --- a/tests/run_dist_tests.py +++ b/tests/run_dist_tests.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import signal import sys import unittest @@ -34,6 +35,9 @@ def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]: @click.command() @click.argument("test_cases", nargs=-1) def run_distributed_suites(test_cases): + if Path.cwd().name != "tests": + os.chdir(Path.cwd() / "tests") + cases_names = get_distributed_test_cases( unittest.defaultTestLoader.discover(".") ) # Don't change the path! From 9a04d726b25a042391a906b1781ee02efcb99d62 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 14:15:18 +0100 Subject: [PATCH 07/21] Amend (fix) latest unit-test action change --- .github/workflows/unit-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 2972480ee..35aa2644c 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,6 +59,6 @@ jobs: echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && echo "Running distributed training tests..." && - PYTHONPATH=.. python run_dist_tests.py && + PYTHONPATH=.. python tests/run_dist_tests.py && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data From 7d7e850c258fc394358f55303db905e504f27d91 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 15:33:38 +0100 Subject: [PATCH 08/21] Fix unit test action --- .github/workflows/unit-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 35aa2644c..8d538e35e 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,6 +59,6 @@ jobs: echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && echo "Running distributed training tests..." && - PYTHONPATH=.. python tests/run_dist_tests.py && + PYTHONPATH=. python tests/run_dist_tests.py && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data From 81b978a2ca4959b9e5fd9ac4d0cc25ad9cdb3b50 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 15:50:24 +0100 Subject: [PATCH 09/21] Fix PYTHONPATH issue in unit-test action --- .github/workflows/unit-test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 8d538e35e..af1941da1 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -51,14 +51,15 @@ jobs: env: FAST_TEST: "True" USE_GPU: "False" + PYTHONPATH: ${{ GITHUB_WORKSPACE }} run: | python -m unittest discover tests && echo "Checking that optional dependencies are not needed" && pip uninstall -y higher ctrl-benchmark torchaudio gym pycocotools lvis && - PYTHONPATH=. python examples/eval_plugin.py && + python examples/eval_plugin.py && echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && echo "Running distributed training tests..." && - PYTHONPATH=. python tests/run_dist_tests.py && + python tests/run_dist_tests.py && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data From 76e13f70998ff3abb80bc2cb743fb8a82211abf2 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 15:51:07 +0100 Subject: [PATCH 10/21] Fix TextLogger checkpoint issue when using pytest --- avalanche/logging/text_logging.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/avalanche/logging/text_logging.py b/avalanche/logging/text_logging.py index 7222410bc..2d3b131f1 100644 --- a/avalanche/logging/text_logging.py +++ b/avalanche/logging/text_logging.py @@ -253,10 +253,10 @@ def _fobj_serialize(file_object) -> Optional[str]: out_file_path = TextLogger._file_get_real_path(file_object) stream_name = TextLogger._file_get_stream(file_object) - if out_file_path is not None: - return "path:" + str(out_file_path) - elif stream_name is not None: + if stream_name is not None: return "stream:" + stream_name + elif out_file_path is not None: + return "path:" + str(out_file_path) else: return None @@ -265,7 +265,7 @@ def _fobj_deserialize(file_def: str) -> Optional[TextIO]: if not isinstance(file_def, str): # Custom object (managed by pickle or dill library) return file_def - + if file_def.startswith("path:"): file_def = _remove_prefix(file_def, "path:") return open(file_def, "a") @@ -285,8 +285,15 @@ def _file_get_real_path(file_object) -> Optional[str]: # Manage files created by tempfile file_object = file_object.file fobject_path = file_object.name + if fobject_path in ["", ""]: + # Standard output / error + return None + + if isinstance(fobject_path, int): + # File descriptor return None + return fobject_path except AttributeError: return None From 8557560b53f84e960fba62954b76715ddb8f3a31 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 15:53:31 +0100 Subject: [PATCH 11/21] Fix lint issue --- avalanche/logging/text_logging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/avalanche/logging/text_logging.py b/avalanche/logging/text_logging.py index 2d3b131f1..17451e344 100644 --- a/avalanche/logging/text_logging.py +++ b/avalanche/logging/text_logging.py @@ -265,7 +265,7 @@ def _fobj_deserialize(file_def: str) -> Optional[TextIO]: if not isinstance(file_def, str): # Custom object (managed by pickle or dill library) return file_def - + if file_def.startswith("path:"): file_def = _remove_prefix(file_def, "path:") return open(file_def, "a") @@ -289,11 +289,11 @@ def _file_get_real_path(file_object) -> Optional[str]: if fobject_path in ["", ""]: # Standard output / error return None - + if isinstance(fobject_path, int): # File descriptor return None - + return fobject_path except AttributeError: return None From 99d1563c0be912cff51c945bf1a868838ccb20a2 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 15:56:12 +0100 Subject: [PATCH 12/21] Fix issue in unit-test action --- .github/workflows/unit-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index af1941da1..745806e17 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -51,7 +51,7 @@ jobs: env: FAST_TEST: "True" USE_GPU: "False" - PYTHONPATH: ${{ GITHUB_WORKSPACE }} + PYTHONPATH: ${{ env.GITHUB_WORKSPACE }} run: | python -m unittest discover tests && echo "Checking that optional dependencies are not needed" && From c52842b83e3c5faeca1ddaeaea8b1407e09bfd47 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 16:25:45 +0100 Subject: [PATCH 13/21] Attempt fix for unit-test action --- .github/workflows/unit-test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 745806e17..6e2b3da10 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -51,8 +51,10 @@ jobs: env: FAST_TEST: "True" USE_GPU: "False" - PYTHONPATH: ${{ env.GITHUB_WORKSPACE }} + PYTHONPATH: ${{ github.workspace }} run: | + bash ./tests/checkpointing/test_checkpointing.sh && + exit 1 && python -m unittest discover tests && echo "Checking that optional dependencies are not needed" && pip uninstall -y higher ctrl-benchmark torchaudio gym pycocotools lvis && From c94f3c1482a0a82aa5dd5ab887d2df8c467132fe Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 16:28:50 +0100 Subject: [PATCH 14/21] Definitive unit test fix --- .github/workflows/unit-test.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 6e2b3da10..f831a3f18 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -53,8 +53,6 @@ jobs: USE_GPU: "False" PYTHONPATH: ${{ github.workspace }} run: | - bash ./tests/checkpointing/test_checkpointing.sh && - exit 1 && python -m unittest discover tests && echo "Checking that optional dependencies are not needed" && pip uninstall -y higher ctrl-benchmark torchaudio gym pycocotools lvis && From f87fc22b434a363a9de79515e24a3fb8181a4f9b Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 16:29:13 +0100 Subject: [PATCH 15/21] Expand tests considered by coveralls --- .github/workflows/test-coverage-coveralls.yml | 11 +++- tests/checkpointing/test_checkpointing.sh | 64 ++++++++++++++++--- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test-coverage-coveralls.yml b/.github/workflows/test-coverage-coveralls.yml index 76cda636b..be6e3f3e7 100644 --- a/.github/workflows/test-coverage-coveralls.yml +++ b/.github/workflows/test-coverage-coveralls.yml @@ -51,9 +51,18 @@ jobs: env: FAST_TEST: "True" USE_GPU: "False" + PYTHONPATH: ${{ github.workspace }} run: | + pip install pytest-xdist pytest-cov && mkdir coverage && - coverage run -m unittest discover tests && + pytest --cov=. -n 4 tests && + mv .coverage coverage_unit_tests && + coverage run examples/eval_plugin.py && + mv .coverage coverage_eval_plugin && + RUN_COVERAGE=True bash ./tests/checkpointing/test_checkpointing.sh && + coverage run tests/run_dist_tests.py && + mv .coverage coverage_dist_tests && + coverage combine coverage_* && coverage lcov -o coverage/lcov.info --omit '**/config-3.py,**/config.py' - name: Upload coverage data to coveralls.io uses: coverallsapp/github-action@master diff --git a/tests/checkpointing/test_checkpointing.sh b/tests/checkpointing/test_checkpointing.sh index eb447897d..a68b9c1b0 100755 --- a/tests/checkpointing/test_checkpointing.sh +++ b/tests/checkpointing/test_checkpointing.sh @@ -2,14 +2,19 @@ # Script used to automatically test various combinations of plugins when used with # the checkpointing functionality. set -euo pipefail -cd tests/checkpointing rm -rf logs rm -rf checkpoints rm -rf metrics_no_checkpoint rm -rf metrics_checkpoint +rm -rf checkpoint.pkl +rm -rf .coverage +rm -rf .coverage_no_checkpoint +rm -rf .coverage_checkpoint1 +rm -rf .coverage_checkpoint2 +rm -rf coverage_checkpointing_* export PYTHONUNBUFFERED=1 -export PYTHONPATH=../.. +export PYTHONPATH=. export CUDA_VISIBLE_DEVICES=0 export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -30,6 +35,7 @@ function str_bool { RUN_FAST_TESTS=$(str_bool "${FAST_TEST:-False}") RUN_GPU_TESTS=$(str_bool "${USE_GPU:-False}") +RUN_COVERAGE=$(str_bool "${RUN_COVERAGE:-False}") GPU_PARAM="--cuda -1" @@ -38,26 +44,66 @@ then GPU_PARAM="--cuda 0" fi +BASE_RUNNER=( + "python" + "-u" +) + +if [ "$RUN_COVERAGE" = "true" ] +then + echo "Running with coverage..." + BASE_RUNNER=( + "coverage" + "run" + ) +fi + +BASE_RUNNER=( + "${BASE_RUNNER[@]}" + "tests/checkpointing/task_incremental_with_checkpointing.py" + $GPU_PARAM + "--benchmark" + $BENCHMARK +) + run_and_check() { # Without checkpoints - python -u task_incremental_with_checkpointing.py $GPU_PARAM --checkpoint_at -1 \ - --plugins "$@" --benchmark $BENCHMARK --log_metrics_to './metrics_no_checkpoint' + "${BASE_RUNNER[@]}" --checkpoint_at -1 \ + --plugins "$@" --log_metrics_to './metrics_no_checkpoint' rm -r ./checkpoint.pkl + if [ "$RUN_COVERAGE" = "true" ] + then + mv .coverage .coverage_no_checkpoint + fi # Stop after experience 1 - python -u task_incremental_with_checkpointing.py $GPU_PARAM --checkpoint_at 1 \ - --plugins "$@" --benchmark $BENCHMARK --log_metrics_to './metrics_checkpoint' + "${BASE_RUNNER[@]}" --checkpoint_at 1 \ + --plugins "$@" --log_metrics_to './metrics_checkpoint' + if [ "$RUN_COVERAGE" = "true" ] + then + mv .coverage .coverage_checkpoint1 + fi echo "Running from checkpoint 1..." - python -u task_incremental_with_checkpointing.py $GPU_PARAM --checkpoint_at 1 \ - --plugins "$@" --benchmark $BENCHMARK --log_metrics_to './metrics_checkpoint' + "${BASE_RUNNER[@]}" --checkpoint_at 1 \ + --plugins "$@" --log_metrics_to './metrics_checkpoint' + if [ "$RUN_COVERAGE" = "true" ] + then + mv .coverage .coverage_checkpoint2 + coverage combine .coverage_no_checkpoint .coverage_checkpoint1 .coverage_checkpoint2 + mv .coverage "coverage_checkpointing_$@" + fi rm -r ./checkpoint.pkl - python -u check_metrics_aligned.py \ + python -u tests/checkpointing/check_metrics_aligned.py \ "./metrics_no_checkpoint" "./metrics_checkpoint" rm -r metrics_no_checkpoint rm -r metrics_checkpoint rm -r logs + + rm -rf .coverage_no_checkpoint + rm -rf .coverage_checkpoint1 + rm -rf .coverage_checkpoint2 } run_and_check "replay" From e4b11efc7db45f621d0fc7c78e2f1078bd9efce0 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 16:45:47 +0100 Subject: [PATCH 16/21] Fix docstring for strategy wrappers --- .../training/supervised/strategy_wrappers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index 31fbe8f4d..541d500b2 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -305,7 +305,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -376,7 +376,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -460,7 +460,7 @@ def __init__( learning experience. :param generator_strategy: A trainable strategy with a generative model, which employs GenerativeReplayPlugin. Defaults to None. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -582,7 +582,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -657,7 +657,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -727,7 +727,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ rp = GSS_greedyPlugin( @@ -1484,7 +1484,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ @@ -1643,7 +1643,7 @@ def __init__( only at the end of the learning experience. Values >0 mean that `eval` is called every `eval_every` epochs and at the end of the learning experience. - :param \*\*base_kwargs: any additional + :param **base_kwargs: any additional :class:`~avalanche.training.BaseTemplate` constructor arguments. """ From a11cd31199081ce3620f50aa2b104656be3553f4 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Fri, 2 Feb 2024 17:06:33 +0100 Subject: [PATCH 17/21] Fix coveralls action --- .github/workflows/test-coverage-coveralls.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-coverage-coveralls.yml b/.github/workflows/test-coverage-coveralls.yml index be6e3f3e7..5de79dbc7 100644 --- a/.github/workflows/test-coverage-coveralls.yml +++ b/.github/workflows/test-coverage-coveralls.yml @@ -46,6 +46,16 @@ jobs: - name: install coverage.py and coverralls run: | pip install coverage coveralls + - name: install pytest-xdist and pytest-cov + run: | + pip install pytest-xdist pytest-cov + - name: download datasets + env: + FAST_TEST: "True" + USE_GPU: "False" + PYTHONPATH: ${{ github.workspace }} + run: | + pytest ./tests/benchmarks/scenarios/deprecated/test_high_level_generators.py::HighLevelGeneratorTests::test_filelist_benchmark ./tests/benchmarks/scenarios/deprecated/test_nc_mt_scenario.py::MultiTaskTests::test_mt_single_dataset - name: python unit test id: unittest env: @@ -53,7 +63,6 @@ jobs: USE_GPU: "False" PYTHONPATH: ${{ github.workspace }} run: | - pip install pytest-xdist pytest-cov && mkdir coverage && pytest --cov=. -n 4 tests && mv .coverage coverage_unit_tests && From 763fd1ceb047b5e9d2a716b3867cc47feaac7b59 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 5 Feb 2024 17:44:45 +0100 Subject: [PATCH 18/21] Move positional parameter management functions --- avalanche/training/templates/base.py | 402 +++++++++++++-------------- 1 file changed, 201 insertions(+), 201 deletions(-) diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index f491a951a..80b47f325 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -28,204 +28,6 @@ TPluginType = TypeVar("TPluginType", bound=BasePlugin, contravariant=True) -def _warn_init_has_positional_args(init_method, class_name): - init_args = inspect.signature(init_method).parameters - positional_args = [ - k - for k, v in init_args.items() - if v.kind - in { - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.VAR_POSITIONAL, - } - ] - if len(positional_args) > 1: # self is always present - warnings.warn( - f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " - f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " - f"Your strategy {class_name}.__init__ method still has some positional-only or " - f"positional-or-keyword arguments. Consider removing them. Offending arguments: {positional_args}. " - f"This can be achieved by adding a * in the argument list of your __init__ method just after 'self'. " - f"More info: https://peps.python.org/pep-3102/#specification" - ) - - -def _merge_legacy_positional_arguments( - init_method, - class_name: str, - args: Sequence[Any], - kwargs: Dict[str, Any], - strict_init_check=True, - allow_pos_args=True, -): - """ - Manage the legacy positional constructor parameters. - - Used to warn the user when passing positional parameters to strategy constructors - (which is deprecated). - - To allow for a smooth transition, we allow the user to pass positional - arguments to the constructor. However, we warn the user that - this soft transition mechanism will be removed in the future. - """ - - if len(args) == 1: - # No positional argument has been passed (good!) - return args, kwargs - elif len(args) == 0: - # This should never happen and will fail later. - # assert len(args) == 0, "At least the 'self' argument should be passed" - return args, kwargs - - all_init_args = dict(inspect.signature(init_method).parameters) - - # Remove 'self' from the list of arguments - all_init_args.pop("self") - - # Divide parameters in groups - pos_only_args = [ - (k, v) - for k, v in all_init_args.items() - if v.kind == inspect.Parameter.POSITIONAL_ONLY - ] - pos_or_keyword = [ - (k, v) - for k, v in all_init_args.items() - if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ] - var_pos = [ - (k, v) - for k, v in all_init_args.items() - if v.kind == inspect.Parameter.VAR_POSITIONAL - ] - keyword_only_args = [ - (k, v) - for k, v in all_init_args.items() - if v.kind == inspect.Parameter.KEYWORD_ONLY - ] - var_keyword_args = [ - (k, v) - for k, v in all_init_args.items() - if v.kind == inspect.Parameter.VAR_KEYWORD - ] - - error_str = ( - f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " - f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " - f"Your are passing {len(args) - 1} positional arguments to the {class_name}.__init__ method. " - f"Consider passing them as names arguments." - ) - - if allow_pos_args: - error_str += ( - " The ability to pass positional arguments will be removed in the future." - ) - warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) - else: - raise PositionalArgumentDeprecatedWarning(error_str) - - args_to_manage = list(args) - kwargs_to_manage = dict(kwargs) - - result_args = [args_to_manage.pop(0)] # Add self - result_kwargs = OrderedDict() - - unset_arguments = set(all_init_args.keys()) - - for argument_name, arg_def in pos_only_args: - if len(args_to_manage) > 0: - result_args.append(args_to_manage.pop(0)) - elif arg_def.default is inspect.Parameter.empty: - raise ValueError( - f"Positional-only argument {argument_name} is not set (and no default set)." - ) - unset_arguments.remove(argument_name) - - for argument_name, arg_def in pos_or_keyword: - if len(args_to_manage) > 0: - result_args.append(args_to_manage.pop(0)) - elif argument_name in kwargs_to_manage: - result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) - elif arg_def.default is inspect.Parameter.empty: - raise ValueError( - f"Parameter {argument_name} is not set (and no default provided)." - ) - - if argument_name not in unset_arguments: - # This is the same error and message raised by Python when passing - # multiple values for an argument. - raise TypeError(f"Got multiple values for argument '{argument_name}'") - unset_arguments.remove(argument_name) - - if len(var_pos) > 0: - # assert len(var_pos) == 1, "Only one var-positional argument is supported" - argument_name = var_pos[0][0] - if len(args_to_manage) > 0: - result_args.extend(args_to_manage) - args_to_manage = list() - - if argument_name not in unset_arguments: - # This is the same error and message raised by Python when passing - # multiple values for an argument. - raise TypeError(f"Got multiple values for argument '{argument_name}'") - - unset_arguments.remove(argument_name) - - for argument_name, arg_def in keyword_only_args: - if len(args_to_manage) > 0 and argument_name in kwargs_to_manage: - raise TypeError( - f"Got multiple values for argument '{argument_name}' (passed as both positional and named parameter)" - ) - - if len(args_to_manage) > 0: - # This is where the soft transition mechanism is implemented. - # The legacy positional arguments are transformed to keyword arguments. - result_kwargs[argument_name] = args_to_manage.pop(0) - elif argument_name in kwargs_to_manage: - result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) - elif arg_def.default is inspect.Parameter.empty: - raise ValueError( - f"Keyword-only parameter {argument_name} is not set (and no default set)." - ) - - if argument_name not in unset_arguments: - # This is the same error and message raised by Python when passing - # multiple values for an argument. - raise TypeError(f"Got multiple values for argument '{argument_name}'") - - unset_arguments.remove(argument_name) - - if len(var_keyword_args) > 0: - # assert len(var_keyword_args) == 1, "Only one var-keyword argument is supported" - argument_name = var_keyword_args[0][0] - result_kwargs.update(kwargs_to_manage) - kwargs_to_manage = dict() - - if argument_name not in unset_arguments: - # This is the same error and message raised by Python when passing - # multiple values for an argument. - raise TypeError(f"Got multiple values for argument '{argument_name}'") - unset_arguments.remove(argument_name) - - assert len(unset_arguments) == 0 - - return result_args, result_kwargs - - -def _support_legacy_strategy_positional_args(func, cls_name): - @functools.wraps(func) - def wrap_init(*args, **kwargs): - # print("wrap_init", cls_name, args[0]) - _warn_init_has_positional_args(func, cls_name) - args, kwargs = _merge_legacy_positional_arguments( - func, cls_name, args, kwargs, strict_init_check=False, allow_pos_args=True - ) - return func(*args, **kwargs) - - return wrap_init - - class BaseTemplate(BaseStrategyProtocol[TExperienceType]): """Base class for continual learning skeletons. @@ -568,8 +370,206 @@ def _experiences_parameter_as_iterable( return [experiences] -__all__ = ["BaseTemplate"] - - class PositionalArgumentDeprecatedWarning(UserWarning): pass + + +def _warn_init_has_positional_args(init_method, class_name): + init_args = inspect.signature(init_method).parameters + positional_args = [ + k + for k, v in init_args.items() + if v.kind + in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + } + ] + if len(positional_args) > 1: # self is always present + warnings.warn( + f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " + f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " + f"Your strategy {class_name}.__init__ method still has some positional-only or " + f"positional-or-keyword arguments. Consider removing them. Offending arguments: {positional_args}. " + f"This can be achieved by adding a * in the argument list of your __init__ method just after 'self'. " + f"More info: https://peps.python.org/pep-3102/#specification" + ) + + +def _merge_legacy_positional_arguments( + init_method, + class_name: str, + args: Sequence[Any], + kwargs: Dict[str, Any], + strict_init_check=True, + allow_pos_args=True, +): + """ + Manage the legacy positional constructor parameters. + + Used to warn the user when passing positional parameters to strategy constructors + (which is deprecated). + + To allow for a smooth transition, we allow the user to pass positional + arguments to the constructor. However, we warn the user that + this soft transition mechanism will be removed in the future. + """ + + if len(args) == 1: + # No positional argument has been passed (good!) + return args, kwargs + elif len(args) == 0: + # This should never happen and will fail later. + # assert len(args) == 0, "At least the 'self' argument should be passed" + return args, kwargs + + all_init_args = dict(inspect.signature(init_method).parameters) + + # Remove 'self' from the list of arguments + all_init_args.pop("self") + + # Divide parameters in groups + pos_only_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.POSITIONAL_ONLY + ] + pos_or_keyword = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + var_pos = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.VAR_POSITIONAL + ] + keyword_only_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.KEYWORD_ONLY + ] + var_keyword_args = [ + (k, v) + for k, v in all_init_args.items() + if v.kind == inspect.Parameter.VAR_KEYWORD + ] + + error_str = ( + f"Avalanche is transitioning to strategy constructors that accept named (keyword) arguments only. " + f"This is done to ensure that there is no confusion regarding the meaning of each argument (strategies can have many arguments). " + f"Your are passing {len(args) - 1} positional arguments to the {class_name}.__init__ method. " + f"Consider passing them as names arguments." + ) + + if allow_pos_args: + error_str += ( + " The ability to pass positional arguments will be removed in the future." + ) + warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + else: + raise PositionalArgumentDeprecatedWarning(error_str) + + args_to_manage = list(args) + kwargs_to_manage = dict(kwargs) + + result_args = [args_to_manage.pop(0)] # Add self + result_kwargs = OrderedDict() + + unset_arguments = set(all_init_args.keys()) + + for argument_name, arg_def in pos_only_args: + if len(args_to_manage) > 0: + result_args.append(args_to_manage.pop(0)) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Positional-only argument {argument_name} is not set (and no default set)." + ) + unset_arguments.remove(argument_name) + + for argument_name, arg_def in pos_or_keyword: + if len(args_to_manage) > 0: + result_args.append(args_to_manage.pop(0)) + elif argument_name in kwargs_to_manage: + result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Parameter {argument_name} is not set (and no default provided)." + ) + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + unset_arguments.remove(argument_name) + + if len(var_pos) > 0: + # assert len(var_pos) == 1, "Only one var-positional argument is supported" + argument_name = var_pos[0][0] + if len(args_to_manage) > 0: + result_args.extend(args_to_manage) + args_to_manage = list() + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + + unset_arguments.remove(argument_name) + + for argument_name, arg_def in keyword_only_args: + if len(args_to_manage) > 0 and argument_name in kwargs_to_manage: + raise TypeError( + f"Got multiple values for argument '{argument_name}' (passed as both positional and named parameter)" + ) + + if len(args_to_manage) > 0: + # This is where the soft transition mechanism is implemented. + # The legacy positional arguments are transformed to keyword arguments. + result_kwargs[argument_name] = args_to_manage.pop(0) + elif argument_name in kwargs_to_manage: + result_kwargs[argument_name] = kwargs_to_manage.pop(argument_name) + elif arg_def.default is inspect.Parameter.empty: + raise ValueError( + f"Keyword-only parameter {argument_name} is not set (and no default set)." + ) + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + + unset_arguments.remove(argument_name) + + if len(var_keyword_args) > 0: + # assert len(var_keyword_args) == 1, "Only one var-keyword argument is supported" + argument_name = var_keyword_args[0][0] + result_kwargs.update(kwargs_to_manage) + kwargs_to_manage = dict() + + if argument_name not in unset_arguments: + # This is the same error and message raised by Python when passing + # multiple values for an argument. + raise TypeError(f"Got multiple values for argument '{argument_name}'") + unset_arguments.remove(argument_name) + + assert len(unset_arguments) == 0 + + return result_args, result_kwargs + + +def _support_legacy_strategy_positional_args(func, cls_name): + @functools.wraps(func) + def wrap_init(*args, **kwargs): + # print("wrap_init", cls_name, args[0]) + _warn_init_has_positional_args(func, cls_name) + args, kwargs = _merge_legacy_positional_arguments( + func, cls_name, args, kwargs, strict_init_check=False, allow_pos_args=True + ) + return func(*args, **kwargs) + + return wrap_init + + +__all__ = ["BaseTemplate"] From 4d15accd1f7ec8f0a13364f9a85451e348e2bd3c Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 5 Feb 2024 17:45:09 +0100 Subject: [PATCH 19/21] Fix lint issues --- avalanche/training/templates/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 80b47f325..d58bae600 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -148,9 +148,9 @@ def train( self.model.to(self.device) # Normalize training and eval data. - experiences_list: Iterable[TExperienceType] = ( - _experiences_parameter_as_iterable(experiences) - ) + experiences_list: Iterable[ + TExperienceType + ] = _experiences_parameter_as_iterable(experiences) if eval_streams is None: eval_streams = [experiences_list] @@ -202,9 +202,9 @@ def eval( self.is_training = False self.model.eval() - experiences_list: Iterable[TExperienceType] = ( - _experiences_parameter_as_iterable(experiences) - ) + experiences_list: Iterable[ + TExperienceType + ] = _experiences_parameter_as_iterable(experiences) self.current_eval_stream = experiences_list self._before_eval(**kwargs) From b949bac5330b7ad5ffd1dedd5aae3965b40d00bf Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 5 Feb 2024 18:06:04 +0100 Subject: [PATCH 20/21] Add named parameters "spellcheck". Fix warning name. --- avalanche/training/templates/base.py | 62 ++++++++++++++++----- tests/models/test_packnet.py | 4 +- tests/training/test_online_strategies.py | 4 +- tests/training/test_strategies.py | 69 ++++++++++++++---------- 4 files changed, 93 insertions(+), 46 deletions(-) diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index d58bae600..433ebba1c 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -8,6 +8,7 @@ OrderedDict, Sequence, Optional, + Type, TypeVar, Union, List, @@ -47,9 +48,7 @@ class BaseTemplate(BaseStrategyProtocol[TExperienceType]): def __init_subclass__(cls, **kwargs): # This is needed to manage the transition to keyword-only arguments. - cls.__init__ = _support_legacy_strategy_positional_args( - cls.__init__, cls.__name__ - ) + cls.__init__ = _support_legacy_strategy_positional_args(cls) super().__init_subclass__(**kwargs) # we need this only for type checking @@ -370,7 +369,7 @@ def _experiences_parameter_as_iterable( return [experiences] -class PositionalArgumentDeprecatedWarning(UserWarning): +class PositionalArgumentsDeprecatedWarning(UserWarning): pass @@ -402,7 +401,6 @@ def _merge_legacy_positional_arguments( class_name: str, args: Sequence[Any], kwargs: Dict[str, Any], - strict_init_check=True, allow_pos_args=True, ): """ @@ -467,9 +465,9 @@ def _merge_legacy_positional_arguments( error_str += ( " The ability to pass positional arguments will be removed in the future." ) - warnings.warn(error_str, category=PositionalArgumentDeprecatedWarning) + warnings.warn(error_str, category=PositionalArgumentsDeprecatedWarning) else: - raise PositionalArgumentDeprecatedWarning(error_str) + raise PositionalArgumentsDeprecatedWarning(error_str) args_to_manage = list(args) kwargs_to_manage = dict(kwargs) @@ -559,15 +557,53 @@ def _merge_legacy_positional_arguments( return result_args, result_kwargs -def _support_legacy_strategy_positional_args(func, cls_name): - @functools.wraps(func) +def _check_mispelled_kwargs(cls: Type, kwargs: Dict[str, Any]): + # First: we gather all parameter names of inits of all the classes in the mro + all_init_args = set() + for c in cls.mro(): + # We then consider only positional_or_keyword and keyword_only arguments + # Also, it does not make sense to include self + all_init_args.update( + k + for k, v in inspect.signature(c.__init__).parameters.items() + if v.kind + in { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } + and k != "self" + ) + + passed_parameters = set(kwargs.keys()) + passed_parameters.discard("self") # self should not be in kwargs, but it's better to be safe + + # Then we check if there are any mispelled/unexpected arguments + unexpected_args = list(passed_parameters - all_init_args) + + if len(unexpected_args) == 1: + raise TypeError( + f"{cls.__name__}.__init__ got an unexpected keyword argument: {unexpected_args[0]}. " + "This parameter is not accepted by the strategy class or any of its super classes. " + "Please check if you have mispelled the parameter name." + ) + elif len(unexpected_args) > 1: + raise TypeError( + f"{cls.__name__}.__init__ got unexpected keyword arguments: {unexpected_args}. " + "Those parameters are not accepted by the strategy class or any of its super classes. " + "Please check if you have mispelled any parameter name." + ) + + +def _support_legacy_strategy_positional_args(cls): + init_method, cls_name = cls.__init__, cls.__name__ + @functools.wraps(init_method) def wrap_init(*args, **kwargs): - # print("wrap_init", cls_name, args[0]) - _warn_init_has_positional_args(func, cls_name) + _warn_init_has_positional_args(init_method, cls_name) args, kwargs = _merge_legacy_positional_arguments( - func, cls_name, args, kwargs, strict_init_check=False, allow_pos_args=True + init_method, cls_name, args, kwargs, allow_pos_args=True ) - return func(*args, **kwargs) + _check_mispelled_kwargs(cls, kwargs) + return init_method(*args, **kwargs) return wrap_init diff --git a/tests/models/test_packnet.py b/tests/models/test_packnet.py index 5b423eca9..c80aeb402 100644 --- a/tests/models/test_packnet.py +++ b/tests/models/test_packnet.py @@ -2,7 +2,7 @@ from avalanche.models.packnet import PackNetModel, packnet_simple_mlp from avalanche.training.supervised.strategy_wrappers import PackNet from avalanche.training.templates.base import ( - PositionalArgumentDeprecatedWarning, + PositionalArgumentsDeprecatedWarning, ) from torch.optim import SGD, Adam import torch @@ -29,7 +29,7 @@ def _test_PackNetPlugin(self, expectations, optimizer_constructor, lr): ) model = construct_model() optimizer = optimizer_constructor(model.parameters(), lr=lr) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = PackNet( model, prune_proportion=0.5, diff --git a/tests/training/test_online_strategies.py b/tests/training/test_online_strategies.py index 852fd0db6..7d95b7856 100644 --- a/tests/training/test_online_strategies.py +++ b/tests/training/test_online_strategies.py @@ -10,7 +10,7 @@ from avalanche.benchmarks.scenarios.online import OnlineCLScenario from avalanche.training import OnlineNaive from avalanche.training.templates.base import ( - PositionalArgumentDeprecatedWarning, + PositionalArgumentsDeprecatedWarning, ) from tests.unit_tests_utils import get_fast_benchmark from avalanche.training.plugins.evaluation import default_evaluator @@ -49,7 +49,7 @@ def test_naive(self): # With task boundaries model, optimizer, criterion, _ = self.init_sit() - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = OnlineNaive( model, optimizer, diff --git a/tests/training/test_strategies.py b/tests/training/test_strategies.py index 30bbd6921..5265d5c31 100644 --- a/tests/training/test_strategies.py +++ b/tests/training/test_strategies.py @@ -61,7 +61,7 @@ from avalanche.training.templates import SupervisedTemplate from avalanche.training.templates.base import _group_experiences_by_stream from avalanche.training.templates.base import ( - PositionalArgumentDeprecatedWarning, + PositionalArgumentsDeprecatedWarning, ) from avalanche.training.utils import get_last_fc_layer from tests.training.test_strategy_utils import run_strategy @@ -275,7 +275,18 @@ def test_naive(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertRaises(TypeError): + strategy = Naive( + model, + optimizer, + criterion, + train_mb_size=64, + device=self.device, + eval_mb_size=50, + trai_epochs=2, + ) + + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = Naive( model, optimizer, @@ -319,7 +330,7 @@ def after_train_dataset_adaptation( # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = JointTraining( model, optimizer, @@ -361,7 +372,7 @@ def test_cwrstar(self): model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) last_fc_name, _ = get_last_fc_layer(model) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = CWRStar( model, optimizer, @@ -413,7 +424,7 @@ def test_cwrstar(self): def test_replay(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = Replay( model, optimizer, @@ -446,7 +457,7 @@ def test_replay(self): def test_mer(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = MER( model, optimizer, @@ -479,7 +490,7 @@ def test_mer(self): def test_gdumb(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = GDumb( model, optimizer, @@ -510,7 +521,7 @@ def test_gdumb(self): def test_cumulative(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = Cumulative( model, optimizer, @@ -537,7 +548,7 @@ def test_cumulative(self): def test_slda(self): model, _, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = StreamingLDA( model, criterion, @@ -567,7 +578,7 @@ def test_warning_slda_lwf(self): def test_lwf(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = LwF( model, optimizer, @@ -599,7 +610,7 @@ def test_lwf(self): def test_agem(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = AGEM( model, optimizer, @@ -629,7 +640,7 @@ def test_agem(self): def test_gem(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = GEM( model, optimizer, @@ -659,7 +670,7 @@ def test_gem(self): def test_ewc(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = EWC( model, optimizer, @@ -690,7 +701,7 @@ def test_ewc(self): def test_ewc_online(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = EWC( model, optimizer, @@ -722,7 +733,7 @@ def test_ewc_online(self): def test_rwalk(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = Naive( model, optimizer, @@ -762,7 +773,7 @@ def test_rwalk(self): def test_synaptic_intelligence(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = SynapticIntelligence( model, optimizer, @@ -794,7 +805,7 @@ def test_cope(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = CoPE( model, optimizer, @@ -831,7 +842,7 @@ def test_pnn(self): # eval on future tasks is not allowed. model = PNN(num_layers=3, in_features=6, hidden_features_per_column=10) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = PNNStrategy( model, optimizer, @@ -850,7 +861,7 @@ def test_pnn(self): def test_expertgate(self): model = ExpertGate(shape=(3, 227, 227), device=self.device) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = ExpertGateStrategy( model, optimizer, @@ -889,7 +900,7 @@ def test_icarl(self): cls = torch.nn.Sigmoid() optimizer = SGD([*fe.parameters(), *cls.parameters()], lr=0.001) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = ICaRL( fe, cls, @@ -909,7 +920,7 @@ def test_icarl(self): def test_lfl(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = LFL( model, optimizer, @@ -940,7 +951,7 @@ def test_lfl(self): def test_mas(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = MAS( model, optimizer, @@ -973,7 +984,7 @@ def test_mas(self): def test_bic(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = BiC( model, optimizer, @@ -994,7 +1005,7 @@ def test_bic(self): def test_mir(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = MIR( model, optimizer, @@ -1028,7 +1039,7 @@ def test_mir(self): def test_erace(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = ER_ACE( model, optimizer, @@ -1045,7 +1056,7 @@ def test_erace(self): def test_eraml(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = ER_AML( model, model.features, @@ -1065,7 +1076,7 @@ def test_eraml(self): def test_l2p(self): _, _, _, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = LearningToPrompt( "simpleMLP", criterion=CrossEntropyLoss(), @@ -1084,7 +1095,7 @@ def test_l2p(self): def test_der(self): # SIT scenario model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False) - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = DER( model, optimizer, @@ -1134,7 +1145,7 @@ def test_feature_replay(self): # We do not add MT criterion cause FeatureReplay uses # MaskedCrossEntropy as main criterion - with self.assertWarns(PositionalArgumentDeprecatedWarning): + with self.assertWarns(PositionalArgumentsDeprecatedWarning): strategy = FeatureReplay( model, optimizer, From 6eb295494e3b09d2cd022b3fb5bc52a8b6c31d31 Mon Sep 17 00:00:00 2001 From: Lorenzo Pellegrini Date: Mon, 5 Feb 2024 18:06:53 +0100 Subject: [PATCH 21/21] Fix linting issue --- avalanche/training/templates/base.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 433ebba1c..2ec228291 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -147,9 +147,9 @@ def train( self.model.to(self.device) # Normalize training and eval data. - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) if eval_streams is None: eval_streams = [experiences_list] @@ -201,9 +201,9 @@ def eval( self.is_training = False self.model.eval() - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) self.current_eval_stream = experiences_list self._before_eval(**kwargs) @@ -575,7 +575,9 @@ def _check_mispelled_kwargs(cls: Type, kwargs: Dict[str, Any]): ) passed_parameters = set(kwargs.keys()) - passed_parameters.discard("self") # self should not be in kwargs, but it's better to be safe + passed_parameters.discard( + "self" + ) # self should not be in kwargs, but it's better to be safe # Then we check if there are any mispelled/unexpected arguments unexpected_args = list(passed_parameters - all_init_args) @@ -596,6 +598,7 @@ def _check_mispelled_kwargs(cls: Type, kwargs: Dict[str, Any]): def _support_legacy_strategy_positional_args(cls): init_method, cls_name = cls.__init__, cls.__name__ + @functools.wraps(init_method) def wrap_init(*args, **kwargs): _warn_init_has_positional_args(init_method, cls_name)