fix DynamicOptimizer initialization
AntonioCarta committed Apr 30, 2024
1 parent 447f168 commit 876140c
Showing 3 changed files with 127 additions and 13 deletions.
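For orientation, here is a minimal usage sketch of the fixed API, assembled from the diff and tests below. The benchmark setup via get_fast_benchmark(use_task_labels=True) is an assumption borrowed from the repository's test helpers, and the inner minibatch loop is elided:

from torch.optim import SGD

from avalanche.core import Agent
from avalanche.models import SimpleMLP, as_multitask
from avalanche.models.dynamic_optimizers import DynamicOptimizer
from avalanche.training import MaskedCrossEntropy
from tests.unit_tests_utils import get_fast_benchmark

bm = get_fast_benchmark(use_task_labels=True)  # assumed benchmark, as in the tests
agent = Agent()
agent.loss = MaskedCrossEntropy()
agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")

# The fixed constructor also takes the model, so the optimizer's parameter
# bookkeeping is initialized before the first adaptation.
opt = SGD(agent.model.parameters(), lr=0.001)
agent.opt = DynamicOptimizer(opt, agent.model, verbose=False)

for exp in bm.train_stream:
    agent.pre_adapt(exp)   # notifies the registered objects (model, optimizer)
    # ... minibatch loop: forward, agent.loss, l.backward(), agent.opt.step() ...
    agent.post_adapt(exp)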
34 changes: 26 additions & 8 deletions avalanche/models/dynamic_optimizers.py
@@ -67,8 +67,19 @@ class DynamicOptimizer(Adaptable):
         agent.pre_adapt(experience)
     """

-    def __init__(self, optim):
+    def __init__(self, optim, model, reset_state=False, verbose=False):
         self.optim = optim
+        self.reset_state = reset_state
+        self.verbose = verbose
+
+        # initialize param-id
+        self._optimized_param_id = update_optimizer(
+            self.optim,
+            dict(model.named_parameters()),
+            None,
+            reset_state=True,
+            verbose=self.verbose,
+        )

     def zero_grad(self):
         self.optim.zero_grad()
@@ -78,10 +89,12 @@ def step(self):

     def pre_adapt(self, agent: Agent, exp: CLExperience):
         """Adapt the optimizer before training on the current experience."""
-        update_optimizer(
+        self._optimized_param_id = update_optimizer(
             self.optim,
-            new_params=dict(agent.model.named_parameters()),
-            optimized_params=dict(agent.model.named_parameters()),
+            dict(agent.model.named_parameters()),
+            self._optimized_param_id,
+            reset_state=self.reset_state,
+            verbose=self.verbose,
         )


@@ -329,16 +342,22 @@ def update_optimizer(
     Newly added parameters are added by default to parameter group 0

-    :param new_params: Dict (name, param) of new parameters
+    WARNING: the first call to `update_optimizer` must be done before
+        calling the model's adaptation.
+
+    :param optimizer: the Optimizer object.
+    :param new_params: Dict (name, param) of new parameters.
     :param optimized_params: Dict (name, param) of
-        currently optimized parameters
+        currently optimized parameters. In most use cases, it will be `None` in
+        the first call and the return value of the last `update_optimizer` call
+        for the subsequent calls.
     :param reset_state: Whether to reset the optimizer's state (i.e momentum).
         Defaults to False.
     :param remove_params: Whether to remove parameters that were in the optimizer
         but are not found in new parameters. For safety reasons,
         defaults to False.
     :param verbose: If True, prints information about inferred
-        parameter groups for new params
+        parameter groups for new params.
     :return: Dict (name, param) of optimized parameters
     """
@@ -381,7 +400,6 @@ def update_optimizer(
         new_p = new_params[key]
         group = param_structure[key].single_group
         optimizer.param_groups[group]["params"].append(new_p)
-        optimized_params[key] = new_p
         optimizer.state[new_p] = {}

     if reset_state:
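The docstring above prescribes a specific calling pattern for update_optimizer; the sketch below mirrors the calls that DynamicOptimizer now makes in this diff (first call with optimized_params=None at construction time, later calls with the previously returned dict). The import path assumes update_optimizer is exported by the same module this diff edits:

from torch.optim import SGD

from avalanche.models import SimpleMLP
from avalanche.models.dynamic_optimizers import update_optimizer

model = SimpleMLP(input_size=6)
optimizer = SGD(model.parameters(), lr=0.01)

# First call: before any model adaptation, with no previously optimized params.
optimized_param_id = update_optimizer(
    optimizer,
    dict(model.named_parameters()),
    None,
    reset_state=True,
)

# ... the model is adapted here (e.g. a classifier head grows) ...

# Later calls: pass the dict returned by the previous call so that only the
# newly created parameters are appended to the optimizer's groups.
optimized_param_id = update_optimizer(
    optimizer,
    dict(model.named_parameters()),
    optimized_param_id,
    reset_state=False,
)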
100 changes: 98 additions & 2 deletions tests/models/test_dynamic_optimizers.py
@@ -1,10 +1,12 @@
 import unittest

+from torch.nn import CrossEntropyLoss
 from torch.optim import SGD
 from torch.utils.data import DataLoader

 from avalanche.core import Agent
-from avalanche.models import SimpleMLP, as_multitask
+from avalanche.models import SimpleMLP, as_multitask, IncrementalClassifier, MTSimpleMLP
+from avalanche.models.dynamic_modules import avalanche_model_adaptation
 from avalanche.models.dynamic_optimizers import DynamicOptimizer
 from avalanche.training import MaskedCrossEntropy
 from tests.unit_tests_utils import get_fast_benchmark
@@ -17,7 +19,7 @@ def test_dynamic_optimizer(self):
         agent.loss = MaskedCrossEntropy()
         agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")
         opt = SGD(agent.model.parameters(), lr=0.001)
-        agent.opt = DynamicOptimizer(opt)
+        agent.opt = DynamicOptimizer(opt, agent.model, verbose=False)

         for exp in bm.train_stream:
             agent.model.train()
@@ -32,3 +34,97 @@ def test_dynamic_optimizer(self):
                 l.backward()
                 agent.opt.step()
             agent.post_adapt(exp)

+    def init_scenario(self, multi_task=False):
+        if multi_task:
+            model = MTSimpleMLP(input_size=6, hidden_size=10)
+        else:
+            model = SimpleMLP(input_size=6, hidden_size=10)
+            model.classifier = IncrementalClassifier(10, 1)
+        criterion = CrossEntropyLoss()
+        benchmark = get_fast_benchmark(use_task_labels=multi_task)
+        return model, criterion, benchmark
+
+    def _is_param_in_optimizer(self, param, optimizer):
+        for group in optimizer.param_groups:
+            for curr_p in group["params"]:
+                if hash(curr_p) == hash(param):
+                    return True
+        return False
+
+    def _is_param_in_optimizer_group(self, param, optimizer):
+        for group_idx, group in enumerate(optimizer.param_groups):
+            for curr_p in group["params"]:
+                if hash(curr_p) == hash(param):
+                    return group_idx
+        return None
+
+    def test_optimizer_groups_clf_til(self):
+        """
+        Tests the automatic assignment of new
+        MultiHead parameters to the optimizer
+        """
+        model, criterion, benchmark = self.init_scenario(multi_task=True)
+
+        g1 = []
+        g2 = []
+        for n, p in model.named_parameters():
+            if "classifier" in n:
+                g1.append(p)
+            else:
+                g2.append(p)
+
+        agent = Agent()
+        agent.model = model
+        optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}])
+        agent.optimizer = DynamicOptimizer(optimizer, model=model, verbose=False)
+
+        for experience in benchmark.train_stream:
+            avalanche_model_adaptation(model, experience)
+            agent.optimizer.pre_adapt(agent, experience)
+
+            for n, p in model.named_parameters():
+                assert self._is_param_in_optimizer(p, agent.optimizer.optim)
+                if "classifier" in n:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0
+                    )
+                else:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1
+                    )
+
+    def test_optimizer_groups_clf_cil(self):
+        """
+        Tests the automatic assignment of new
+        IncrementalClassifier parameters to the optimizer
+        """
+        model, criterion, benchmark = self.init_scenario(multi_task=False)
+
+        g1 = []
+        g2 = []
+        for n, p in model.named_parameters():
+            if "classifier" in n:
+                g1.append(p)
+            else:
+                g2.append(p)
+
+        agent = Agent()
+        agent.model = model
+        optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}])
+        agent.optimizer = DynamicOptimizer(optimizer, model)
+
+        for experience in benchmark.train_stream:
+            avalanche_model_adaptation(model, experience)
+            agent.optimizer.pre_adapt(agent, experience)
+
+            for n, p in model.named_parameters():
+                assert self._is_param_in_optimizer(p, agent.optimizer.optim)
+                if "classifier" in n:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0
+                    )
+                else:
+                    self.assertEqual(
+                        self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1
+                    )
6 changes: 3 additions & 3 deletions tests/models/test_models.py
@@ -369,10 +369,10 @@ def test_recursive_adaptation(self):
         sizes = {}
         for t, exp in enumerate(benchmark.train_stream):
             sizes[t] = np.max(exp.classes_in_this_experience) + 1
-            model1.pre_adapt(exp)
+            model1.pre_adapt(None, exp)
         # Second adaptation should not change anything
         for t, exp in enumerate(benchmark.train_stream):
-            model1.pre_adapt(exp)
+            model1.pre_adapt(None, exp)
         for t, s in sizes.items():
             self.assertEqual(s, model1.classifiers[str(t)].classifier.out_features)

@@ -385,7 +385,7 @@ def test_recursive_loop(self):
         model2.layer2 = model1

         benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
-        model1.pre_adapt(benchmark.train_stream[0])
+        model1.pre_adapt(None, benchmark.train_stream[0])

     def test_multi_head_classifier_masking(self):
         benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
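The test_models.py changes above reflect the two-argument pre_adapt(agent, experience) signature used throughout this commit, with None standing in for the agent when a dynamic module is adapted directly. A minimal sketch under that assumption, reusing the IncrementalClassifier and benchmark helper that appear in the tests:

from avalanche.models import IncrementalClassifier
from tests.unit_tests_utils import get_fast_benchmark

benchmark = get_fast_benchmark(use_task_labels=False)
clf = IncrementalClassifier(10, 1)  # 10 input features, 1 initial output unit

for exp in benchmark.train_stream:
    # No Agent is involved here, so pass None as the first argument.
    clf.pre_adapt(None, exp)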
