Skip to content

Commit

Permalink
Merge pull request #1550 from lrzpellegrini/bic_fix
Browse files Browse the repository at this point in the history
Major BiC fix/reimplementation
  • Loading branch information
AntonioCarta authored Jan 25, 2024
2 parents 30bb612 + 96838ed commit 40119a0
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 104 deletions.
18 changes: 11 additions & 7 deletions avalanche/models/bic_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Iterable, SupportsInt
import torch


Expand All @@ -10,25 +11,28 @@ class BiasLayer(torch.nn.Module):
Recognition. 2019"
"""

def __init__(self, device, clss):
def __init__(self, clss: Iterable[SupportsInt]):
"""
:param device: device used by the main model. 'cpu' or 'cuda'
:param clss: list of classes of the current layer. This are use
to identify the columns which are multiplied by the Bias
correction Layer.
"""
super().__init__()
self.alpha = torch.nn.Parameter(torch.ones(1, device=device))
self.beta = torch.nn.Parameter(torch.zeros(1, device=device))
self.alpha = torch.nn.Parameter(torch.ones(1))
self.beta = torch.nn.Parameter(torch.zeros(1))

self.clss = torch.Tensor(list(clss)).long().to(device)
self.not_clss = None
unique_classes = list(sorted(set(int(x) for x in clss)))

self.register_buffer("clss", torch.tensor(unique_classes, dtype=torch.long))

def forward(self, x):
alpha = torch.ones_like(x)
beta = torch.ones_like(x)
beta = torch.zeros_like(x)

alpha[:, self.clss] = self.alpha
beta[:, self.clss] = self.beta

return alpha * x + beta


__all__ = ["BiasLayer"]
Loading

0 comments on commit 40119a0

Please sign in to comment.