Skip to content

Commit

Permalink
Merge branch 'master' into checkpoint_improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta authored Jan 25, 2024
2 parents 87c0932 + 74ce960 commit cddcde9
Show file tree
Hide file tree
Showing 27 changed files with 311 additions and 149 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/environment-update.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ jobs:
- name: install conda environment
run: |
mamba create -n avalanche-env -y -v python=${{ matrix.python-version }} -c conda-forge &&
conda run -n avalanche-env --no-capture-output mamba install -y -v pytorch torchvision cpuonly -c pytorch &&
conda run -n avalanche-env --no-capture-output mamba env update --file environment.yml -v
conda run -n avalanche-env --no-capture-output pip install -r requirements.txt
- name: python unit test
id: unittest
env:
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/classic/openloris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
a number of configuration parameters."""

from pathlib import Path
from typing import Union, Any, Optional
from typing_extensions import Literal
from typing import Union, Any, Optional, Literal

from avalanche.benchmarks.classic.classic_benchmarks_utils import (
check_vision_benchmark,
Expand Down
4 changes: 1 addition & 3 deletions avalanche/benchmarks/classic/stream51.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
# Website: www.continualai.org #
################################################################################
from pathlib import Path
from typing import List, Optional, Union

from typing_extensions import Literal
from typing import List, Optional, Union, Literal

from avalanche.benchmarks.datasets import Stream51
from avalanche.benchmarks.scenarios.deprecated.generic_benchmark_creation import (
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
""" LVIS PyTorch Object Detection Dataset """

from pathlib import Path
from typing import Optional, Union, List, Sequence
import dill
from typing import Optional, Union, List, Sequence, TypedDict

import torch
from PIL import Image
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor
from typing_extensions import TypedDict

from avalanche.benchmarks.datasets import (
DownloadableDataset,
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/datasets/mini_imagenet/mini_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@
import glob
import dill
from pathlib import Path
from typing import Union, List, Tuple, Dict
from typing import Union, List, Tuple, Dict, Literal

from torchvision.datasets.folder import default_loader
from typing_extensions import Literal

import PIL
import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
Iterator,
TypeVar,
Union,
overload,
)
from typing_extensions import overload
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing

Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/scenarios/generic_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
Union,
Generic,
overload,
final,
)
from typing_extensions import final

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/scenarios/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
TypeVar,
Union,
Protocol,
Literal,
)
from typing_extensions import Literal
import warnings
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.utils import concat_datasets
Expand Down
4 changes: 0 additions & 4 deletions avalanche/benchmarks/utils/collate_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,19 @@
################################################################################

import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (
List,
TypeVar,
Generic,
Sequence,
Tuple,
Dict,
Union,
overload,
)
from typing_extensions import TypeAlias

import torch
from torch import Tensor
from torch.utils.data import default_collate

BatchT = TypeVar("BatchT")
ExampleT = TypeVar("ExampleT")
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/utils/dataset_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
# Website: avalanche.continualai.org #
################################################################################

from typing import TypeVar, SupportsInt, Sequence
from typing import TypeVar, SupportsInt, Sequence, Protocol

from torch.utils.data.dataset import Dataset
from typing_extensions import Protocol

T_co = TypeVar("T_co", covariant=True)
TTargetType = TypeVar("TTargetType")
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from abc import ABC, abstractmethod
import bisect
import copy
from typing import Iterator, overload
from typing_extensions import final
from typing import Iterator, overload, final
import numpy as np
from numpy import ndarray
from torch import Tensor
Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/utils/transform_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
Union,
Callable,
Sequence,
Protocol,
)
from typing_extensions import Protocol

from avalanche.benchmarks.utils.transforms import (
MultiParamCompose,
Expand Down
3 changes: 1 addition & 2 deletions avalanche/distributed/distributed_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import pickle
import warnings
from io import BytesIO
from typing import ContextManager, Optional, List, Any, Iterable, Dict, TypeVar
from typing import ContextManager, Optional, List, Any, Iterable, Dict, TypeVar, Literal

import torch
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.parallel import DistributedDataParallel
from typing_extensions import Literal
from torch.distributed import init_process_group, broadcast_object_list


Expand Down
3 changes: 2 additions & 1 deletion avalanche/evaluation/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
List,
Union,
overload,
Literal,
Protocol,
)
from typing_extensions import Literal, Protocol
from .metric_results import MetricValue, MetricType, AlternativeValues
from .metric_utils import (
get_metric_name,
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
################################################################################
from matplotlib.figure import Figure
from numpy import arange
from typing_extensions import Literal
from typing import (
Any,
Callable,
Expand All @@ -19,6 +18,7 @@
Optional,
TYPE_CHECKING,
List,
Literal,
)

import wandb
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Callable,
Sequence,
Optional,
Protocol,
)

from avalanche.benchmarks.utils.data import AvalancheDataset
Expand All @@ -41,7 +42,6 @@
from json import JSONEncoder

from torch.utils.data import Subset, ConcatDataset
from typing_extensions import Protocol

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue
Expand Down
4 changes: 1 addition & 3 deletions avalanche/evaluation/metrics/images_samples.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Tuple
from typing import List, TYPE_CHECKING, Tuple, Literal

from torch import Tensor
from torch.utils.data import DataLoader
Expand All @@ -15,8 +15,6 @@
)
from avalanche.evaluation.metric_utils import get_metric_name

from typing_extensions import Literal


if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
Expand Down
3 changes: 1 addition & 2 deletions avalanche/evaluation/metrics/labels_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Counter,
overload,
Literal,
)

from matplotlib.figure import Figure
Expand All @@ -19,8 +20,6 @@
default_history_repartition_image_creator,
)

from typing_extensions import Literal


if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
Expand Down
4 changes: 1 addition & 3 deletions avalanche/evaluation/metrics/mean_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional, TypeVar
from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional, TypeVar, Literal

import torch
from matplotlib.axes import Axes
Expand All @@ -26,8 +26,6 @@
from avalanche.evaluation.metric_results import MetricValue, AlternativeValues


from typing_extensions import Literal

if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metric_results import MetricResult
Expand Down
18 changes: 11 additions & 7 deletions avalanche/models/bic_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Iterable, SupportsInt
import torch


Expand All @@ -10,25 +11,28 @@ class BiasLayer(torch.nn.Module):
Recognition. 2019"
"""

def __init__(self, device, clss):
def __init__(self, clss: Iterable[SupportsInt]):
"""
:param device: device used by the main model. 'cpu' or 'cuda'
:param clss: list of classes of the current layer. This are use
to identify the columns which are multiplied by the Bias
correction Layer.
"""
super().__init__()
self.alpha = torch.nn.Parameter(torch.ones(1, device=device))
self.beta = torch.nn.Parameter(torch.zeros(1, device=device))
self.alpha = torch.nn.Parameter(torch.ones(1))
self.beta = torch.nn.Parameter(torch.zeros(1))

self.clss = torch.Tensor(list(clss)).long().to(device)
self.not_clss = None
unique_classes = list(sorted(set(int(x) for x in clss)))

self.register_buffer("clss", torch.tensor(unique_classes, dtype=torch.long))

def forward(self, x):
alpha = torch.ones_like(x)
beta = torch.ones_like(x)
beta = torch.zeros_like(x)

alpha[:, self.clss] = self.alpha
beta[:, self.clss] = self.beta

return alpha * x + beta


__all__ = ["BiasLayer"]
Loading

0 comments on commit cddcde9

Please sign in to comment.