Fix multiple inheritance in Python 3.11 #1587

Merged
Changes from 1 commit (of 21 commits in this pull request)
Fix minor issue. Move declaration of CriterionType.
lrzpellegrini committed Feb 2, 2024
commit 384d82a8c33702c5130d800cbaab189ee75dcaf1
2 changes: 0 additions & 2 deletions .github/workflows/unit-test.yml
@@ -59,8 +59,6 @@ jobs:
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd .. &&
echo "While running unit tests, the following datasets were downloaded:" &&
ls ~/.avalanche/data
7 changes: 2 additions & 5 deletions avalanche/training/supervised/ar1.py
@@ -17,7 +17,6 @@
CWRStarPlugin,
)
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.utils import (
replace_bn_with_brn,
get_last_fc_layer,
@@ -27,6 +26,7 @@
LayerAndParameter,
)
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class AR1(SupervisedTemplate):
@@ -44,7 +44,7 @@ class AR1(SupervisedTemplate):
def __init__(
self,
*,
criterion: CriterionType = None,
criterion: CriterionType = CrossEntropyLoss(),
lr: float = 0.001,
inc_lr: float = 5e-5,
momentum=0.9,
@@ -149,9 +149,6 @@ def __init__(

optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2)

if criterion is None:
criterion = CrossEntropyLoss()

self.ewc_lambda = ewc_lambda
self.freeze_below_layer = freeze_below_layer
self.rm_sz = rm_sz
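
For context, a minimal sketch of the pattern this hunk adopts (illustrative code, not taken from the Avalanche codebase; the function name is made up): the old criterion=None parameter plus runtime fallback is replaced by a CrossEntropyLoss() default directly in the signature, which is safe because the loss module keeps no per-call state.

# Illustrative sketch only, assuming the CriterionType alias shown later in this commit.
from typing import Callable, Union

from torch import Tensor
from torch.nn import CrossEntropyLoss, Module

CriterionType = Union[Module, Callable[[Tensor, Tensor], Tensor]]

def build_strategy(criterion: CriterionType = CrossEntropyLoss()) -> CriterionType:
    # A shared default CrossEntropyLoss() instance is fine here: it holds no
    # per-call state, so the usual mutable-default pitfall does not apply.
    return criterion
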
2 changes: 1 addition & 1 deletion avalanche/training/supervised/cumulative.py
@@ -8,7 +8,7 @@
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class Cumulative(SupervisedTemplate):
2 changes: 1 addition & 1 deletion avalanche/training/supervised/deep_slda.py
@@ -13,7 +13,7 @@
)
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models import FeatureExtractorBackbone
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class StreamingLDA(SupervisedTemplate):
2 changes: 1 addition & 1 deletion avalanche/training/supervised/der.py
@@ -19,7 +19,7 @@
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute
from avalanche.benchmarks.utils.flat_data import FlatData
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.core import SupervisedPlugin
from avalanche.training.plugins.evaluation import (
2 changes: 1 addition & 1 deletion avalanche/training/supervised/er_ace.py
@@ -13,7 +13,7 @@
)
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle


2 changes: 1 addition & 1 deletion avalanche/training/supervised/er_aml.py
@@ -13,7 +13,7 @@
)
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle


2 changes: 1 addition & 1 deletion avalanche/training/supervised/expert_gate.py
@@ -22,7 +22,7 @@
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class ExpertGateStrategy(SupervisedTemplate):
2 changes: 1 addition & 1 deletion avalanche/training/supervised/feature_replay.py
@@ -11,7 +11,7 @@
from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.training.losses import MaskedCrossEntropy

2 changes: 1 addition & 1 deletion avalanche/training/supervised/joint_training.py
@@ -28,7 +28,7 @@
_experiences_parameter_as_iterable,
_group_experiences_by_stream,
)
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class AlreadyTrainedError(Exception):
2 changes: 2 additions & 0 deletions avalanche/training/supervised/lamaml.py
@@ -8,6 +8,8 @@
from torch.optim import Optimizer
import math

from avalanche.training.templates.strategy_mixin_protocol import CriterionType

try:
import higher
except ImportError:
2 changes: 2 additions & 0 deletions avalanche/training/supervised/lamaml_v2.py
@@ -3,6 +3,8 @@
import warnings
import torch

from avalanche.training.templates.strategy_mixin_protocol import CriterionType

if parse(torch.__version__) < parse("2.0.0"):
warnings.warn(f"LaMAML requires torch >= 2.0.0.")

2 changes: 1 addition & 1 deletion avalanche/training/supervised/mer.py
@@ -11,7 +11,7 @@
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.storage_policy import ReservoirSamplingBuffer
from avalanche.training.templates import SupervisedMetaLearningTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class MERBuffer:
4 changes: 2 additions & 2 deletions avalanche/training/supervised/strategy_wrappers.py
@@ -45,11 +45,11 @@
)
from avalanche.training.templates.base import BaseTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.evaluation.metrics import loss_metrics
from avalanche.models.generator import MlpVAE, VAE_loss
from avalanche.models.expert_gate import AE_loss
from avalanche.logging import InteractiveLogger
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class Naive(SupervisedTemplate):
2 changes: 1 addition & 1 deletion avalanche/training/supervised/strategy_wrappers_online.py
@@ -21,7 +21,7 @@
SupervisedTemplate,
)
from avalanche._annotations import deprecated
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


@deprecated(
1 change: 1 addition & 0 deletions avalanche/training/templates/base.py
@@ -248,6 +248,7 @@ def __init_subclass__(cls, **kwargs):
cls.__init__ = _support_legacy_strategy_positional_args(
cls.__init__, cls.__name__
)
super().__init_subclass__(**kwargs)

# we need this only for type checking
PLUGIN_CLASS = BasePlugin
39 changes: 19 additions & 20 deletions avalanche/training/templates/base_sgd.py
@@ -1,5 +1,5 @@
import sys
from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union
from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union
from typing_extensions import TypeAlias
from packaging.version import parse

@@ -22,18 +22,20 @@
collate_from_data_or_kwargs,
)

from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol
from avalanche.training.templates.strategy_mixin_protocol import (
CriterionType,
SGDStrategyProtocol,
)
from avalanche.training.utils import trigger_plugins


TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience)
TMBInput = TypeVar("TMBInput")
TMBOutput = TypeVar("TMBOutput")

CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]]


class BaseSGDTemplate(
Generic[TDatasetExperience, TMBInput, TMBOutput],
SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput],
BaseTemplate[TDatasetExperience],
):
@@ -93,21 +95,6 @@ def __init__(
`eval_every` epochs or iterations (Default='epoch').
"""

super().__init__(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
peval_mode=peval_mode,
**kwargs
)

# Call super with all args
if sys.version_info >= (3, 11):
super().__init__(
@@ -127,7 +114,19 @@ def __init__(
else:
super().__init__() # type: ignore
BaseTemplate.__init__(
self=self, model=model, device=device, plugins=plugins
self,
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
peval_mode=peval_mode,
**kwargs
)

self.optimizer: Optimizer = optimizer
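
A self-contained sketch of the version-gated initialization pattern this hunk settles on (class names below are hypothetical, not the Avalanche ones): on Python 3.11 and later the whole MRO is initialized through a single cooperative super() call, while on older interpreters the mixin chain is invoked without arguments and the concrete base class is initialized explicitly.

# Hypothetical classes illustrating the version gate; not Avalanche code.
import sys

class Base:
    def __init__(self, *, model=None, **kwargs):
        self.model = model

class Mixin:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

class Strategy(Mixin, Base):
    def __init__(self, *, model=None, **kwargs):
        if sys.version_info >= (3, 11):
            # Cooperative chain: every class in the MRO forwards its kwargs.
            super().__init__(model=model, **kwargs)
        else:
            # Older interpreters: run the mixin chain without arguments,
            # then initialize the concrete base explicitly.
            super().__init__()
            Base.__init__(self, model=model, **kwargs)

Either branch leaves the instance in the same state; the gate mirrors the goal of this pull request, keeping multiple inheritance working on Python 3.11 without breaking older interpreters.
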
10 changes: 4 additions & 6 deletions avalanche/training/templates/common_templates.py
@@ -1,8 +1,6 @@
import sys
from typing import Any, Callable, Dict, Sequence, Optional, TypeVar, Union
import warnings
from typing import Callable, Sequence, Optional, TypeVar, Union
import torch
import inspect

from torch.nn import Module, CrossEntropyLoss
from torch.optim import Optimizer
@@ -13,8 +11,8 @@
EvaluationPlugin,
default_evaluator,
)
from avalanche.training.templates.base import PositionalArgumentDeprecatedWarning
from avalanche.training.templates.strategy_mixin_protocol import (
CriterionType,
SupervisedStrategyProtocol,
TMBOutput,
TMBInput,
@@ -23,7 +21,7 @@
from .observation_type import *
from .problem_type import *
from .update_type import *
from .base_sgd import BaseSGDTemplate, CriterionType
from .base_sgd import BaseSGDTemplate


TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience)
@@ -202,7 +200,7 @@ def __init__(
*,
model: Module,
optimizer: Optimizer,
criterion=CrossEntropyLoss(),
criterion: CriterionType = CrossEntropyLoss(),
train_mb_size: int = 1,
train_epochs: int = 1,
eval_mb_size: Optional[int] = 1,
10 changes: 7 additions & 3 deletions avalanche/training/templates/strategy_mixin_protocol.py
@@ -1,4 +1,5 @@
from typing import Iterable, List, Optional, TypeVar, Protocol
from typing import Generic, Iterable, List, Optional, TypeVar, Protocol, Callable, Union
from typing_extensions import TypeAlias

from torch import Tensor
import torch
@@ -17,8 +18,10 @@
TMBInput = TypeVar("TMBInput")
TMBOutput = TypeVar("TMBOutput")

CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]]

class BaseStrategyProtocol(Protocol[TExperienceType]):

class BaseStrategyProtocol(Generic[TExperienceType], Protocol[TExperienceType]):
model: Module

device: torch.device
@@ -33,6 +36,7 @@ class BaseStrategyProtocol(Protocol[TExperienceType]):


class SGDStrategyProtocol(
Generic[TSGDExperienceType, TMBInput, TMBOutput],
BaseStrategyProtocol[TSGDExperienceType],
Protocol[TSGDExperienceType, TMBInput, TMBOutput],
):
@@ -52,7 +56,7 @@ class SGDStrategyProtocol(

loss: Tensor

_criterion: Module
_criterion: CriterionType

def forward(self) -> TMBOutput: ...

4 changes: 4 additions & 0 deletions tests/run_dist_tests.py
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import signal
import sys
import unittest
@@ -34,6 +35,9 @@ def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]:
@click.command()
@click.argument("test_cases", nargs=-1)
def run_distributed_suites(test_cases):
if Path.cwd().name != "tests":
os.chdir(Path.cwd() / "tests")

cases_names = get_distributed_test_cases(
unittest.defaultTestLoader.discover(".")
) # Don't change the path!
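
This guard pairs with the simplified distributed-test step in .github/workflows/unit-test.yml above: the script now normalizes its own working directory instead of relying on the caller. A standalone sketch of the same idea (the helper name is made up for illustration):

# Hypothetical helper: normalize the CWD so unittest discovery with
# discover(".") finds the test modules whether the script is started from the
# repository root or from tests/ itself.
import os
from pathlib import Path

def ensure_tests_cwd(tests_dirname: str = "tests") -> None:
    if Path.cwd().name != tests_dirname:
        os.chdir(Path.cwd() / tests_dirname)
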