From e2c22e98c249ff054a52b8d616f02a917cbb338d Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 13 Mar 2024 09:10:57 +0100 Subject: [PATCH 01/18] add seed argument to split_online_stream --- avalanche/benchmarks/scenarios/online.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/avalanche/benchmarks/scenarios/online.py b/avalanche/benchmarks/scenarios/online.py index b88c427d4..6f692975b 100644 --- a/avalanche/benchmarks/scenarios/online.py +++ b/avalanche/benchmarks/scenarios/online.py @@ -167,6 +167,7 @@ def __init__( shuffle: bool = True, drop_last: bool = False, access_task_boundaries: bool = False, + seed: int = None, ) -> None: """Returns a lazy stream generated by splitting an experience into smaller ones. @@ -181,7 +182,8 @@ def __init__( :param experience_size: The experience size (number of instances). :param shuffle: If True, instances will be shuffled before splitting. :param drop_last: If True, the last mini-experience will be dropped if - not of size `experience_size` + not of size `experience_size`. + :param seed: random seed for shuffling the data if `shuffle == True`. :return: The list of datasets that will be used to create the mini-experiences. """ @@ -190,10 +192,12 @@ def __init__( self.shuffle = shuffle self.drop_last = drop_last self.access_task_boundaries = access_task_boundaries + self.seed = seed # we need to fix the seed because repeated calls to the generator # must return the same order every time. - self.seed = random.randint(0, 2**32 - 1) + if seed is None: + self.seed = random.randint(0, 2**32 - 1) def __iter__(self) -> Generator[OnlineCLExperience, None, None]: exp_dataset = self.experience.dataset @@ -250,6 +254,7 @@ def _default_online_split( access_task_boundaries: bool, exp: DatasetExperience, size: int, + seed: int ): return FixedSizeExperienceSplitter( experience=exp, @@ -257,6 +262,7 @@ def _default_online_split( shuffle=shuffle, drop_last=drop_last, access_task_boundaries=access_task_boundaries, + seed=seed ) @@ -272,6 +278,7 @@ def split_online_stream( ] ] = None, access_task_boundaries: bool = False, + seed: int = None, ) -> CLStream[DatasetExperience[TCLDataset]]: """Split a stream of large batches to create an online stream of small mini-batches. @@ -300,6 +307,7 @@ def split_online_stream( A good starting to understand the mechanism is to look at the implementation of the standard splitting function :func:`fixed_size_experience_split`. + :param seed: random seed used for shuffling by the default splitter. :return: A lazy online stream with experiences of size `experience_size`. """ @@ -308,7 +316,8 @@ def split_online_stream( # However, MyPy does not understand what a partial is -_- def default_online_split_wrapper(e, e_sz): return _default_online_split( - shuffle, drop_last, access_task_boundaries, e, e_sz + shuffle, drop_last, access_task_boundaries, e, e_sz, + seed=seed ) split_strategy = default_online_split_wrapper From aed48cc4bf0ae7eb1047f656b31350697ada9478 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 13 Mar 2024 16:56:48 +0100 Subject: [PATCH 02/18] metrics as updatable objects + metric collector and plot utilities --- .../scenarios/validation_scenario.py | 1 + avalanche/evaluation/collector.py | 45 +++++++ avalanche/evaluation/functional.py | 33 +++++ avalanche/evaluation/plot_utils.py | 34 +++++ examples/updatable_objects.py | 117 ++++++++++++++++++ tests/evaluation/test_functional.py | 45 +++++++ tests/evaluation/test_plots.py | 23 ++++ 7 files changed, 298 insertions(+) create mode 100644 avalanche/evaluation/collector.py create mode 100644 avalanche/evaluation/functional.py create mode 100644 examples/updatable_objects.py create mode 100644 tests/evaluation/test_functional.py create mode 100644 tests/evaluation/test_plots.py diff --git a/avalanche/benchmarks/scenarios/validation_scenario.py b/avalanche/benchmarks/scenarios/validation_scenario.py index 74d3bb2ad..d3ad9e4ce 100644 --- a/avalanche/benchmarks/scenarios/validation_scenario.py +++ b/avalanche/benchmarks/scenarios/validation_scenario.py @@ -96,6 +96,7 @@ def random_validation_split_strategy_wrapper(data): # don't drop classes-timeline for compatibility with old API e0 = next(iter(train_stream)) + if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"): train_stream = with_classes_timeline(train_stream) valid_stream = with_classes_timeline(valid_stream) diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py new file mode 100644 index 000000000..09fdb8162 --- /dev/null +++ b/avalanche/evaluation/collector.py @@ -0,0 +1,45 @@ +import numpy as np + +from avalanche._annotations import experimental + + +@experimental() +class MetricCollector: + def __init__(self, stream): + self.metrics_res = None + + self._stream_len = 0 + self._coeffs = [] + for exp in stream: + self._stream_len += 1 + self._coeffs.append(len(exp.dataset)) + self._coeffs = np.array(self._coeffs) + self._coeffs = self._coeffs / self._coeffs.sum() + + def update(self, res): + if self.metrics_res is None: + self.metrics_res = {} + for k, v in res.items(): + if len(v) != self._stream_len: + raise ValueError(f"Length does not correspond to stream. " + f"Found {len(v)}, expected {self._stream_len}") + self.metrics_res[k] = [v] + else: + for k, v in res.items(): + self.metrics_res[k].append(v) + + def get(self, name, time_reduce=None, exp_reduce=None): + assert time_reduce in {None} # 'last', 'average' + assert exp_reduce in {None, "sample_mean", "experience_mean"} + assert name in self.metrics_res + + mvals = np.array(self.metrics_res[name]) + if exp_reduce is None: + return mvals + elif exp_reduce == "sample_mean": + mvals = mvals * self._coeffs[None, :] # weight by num.samples + return mvals.sum(axis=1) # weighted avg across exp. + elif exp_reduce == "experience_means": + return mvals.mean(axis=1) # avg across exp. + else: + raise ValueError("BUG. It should never get here.") diff --git a/avalanche/evaluation/functional.py b/avalanche/evaluation/functional.py new file mode 100644 index 000000000..fbe24b5d9 --- /dev/null +++ b/avalanche/evaluation/functional.py @@ -0,0 +1,33 @@ +import numpy as np + + +def forgetting(accuracy_matrix, boundary_indices=None): + """Forgetting. + + Given an experience `k` learned at time `boundary_indices[k]`, + forgetting is the `accuracy_matrix[t, k] - accuracy_matrix[t, boundary_indices[k]]`. + Forgetting is set to zero before learning on the experience. + + :param accuracy_matrix: 2D accuracy matrix with shape + :param boundary_indices: time index for each experience, corresponding to + the time after the experience was learned. Optional if + `accuracy_matrix` is a square matrix, which is the most common case. + In this setting, `boundary_indices[k] = k`. + :return: + """ + accuracy_matrix = np.array(accuracy_matrix) + forgetting_matrix = np.zeros_like(accuracy_matrix) + + if boundary_indices is None: + assert accuracy_matrix.shape[0] == accuracy_matrix.shape[1] + boundary_indices = list(range(accuracy_matrix.shape[0])) + + for k in range(accuracy_matrix.shape[1]): + t_task = boundary_indices[k] + acc_first = accuracy_matrix[t_task][k] + + # before learning exp k, forgetting is zero + forgetting_matrix[:t_task+1, k] = 0 + # then, it's acc_first - acc[t, k] + forgetting_matrix[t_task+1:, k] = acc_first - accuracy_matrix[t_task+1:, k] + return forgetting_matrix diff --git a/avalanche/evaluation/plot_utils.py b/avalanche/evaluation/plot_utils.py index 44a05ee6d..0a09ef0a0 100644 --- a/avalanche/evaluation/plot_utils.py +++ b/avalanche/evaluation/plot_utils.py @@ -9,6 +9,8 @@ # Website: avalanche.continualai.org # ################################################################################ import matplotlib.pyplot as plt +import numpy as np +from matplotlib import pyplot as plt def learning_curves_plot(all_metrics: dict): @@ -28,3 +30,35 @@ def learning_curves_plot(all_metrics: dict): ax.set_xlabel("Iterations") ax.set_ylabel("Experience Accuracy") return fig + + +def plot_metric_matrix( + metric_matrix, + title, + *, + ax=None, +): + """Plot a matrix of metrics (e.g. forgetting over time). + + :param metric_matrix: 2D accuracy matrix with shape + :param title: plot title + :param ax: axes to use for figure + + :return: a matplotlib.Figure + """ + if ax is None: + fig, ax = plt.subplots() + else: + fig = ax.figure + + metric_matrix = np.array(metric_matrix).T + ax.matshow(metric_matrix) + ax.set_ylabel("Eval Experience") + ax.set_xlabel("Time") + ax.set_title(title) + + for i in range(len(metric_matrix)): + for j in range(len(metric_matrix[0])): + ax.text(j, i, f"{metric_matrix[i][j]:.3f}", + ha="center", va="center", color="w") + return fig diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py new file mode 100644 index 000000000..e314a52cd --- /dev/null +++ b/examples/updatable_objects.py @@ -0,0 +1,117 @@ +from types import SimpleNamespace + +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import ExponentialLR +from torch.utils.data import DataLoader + +from avalanche.benchmarks import SplitMNIST, with_classes_timeline +from avalanche.benchmarks.scenarios import split_online_stream +from avalanche.benchmarks.utils.data_loader import ReplayDataLoader + +from avalanche.evaluation.collector import MetricCollector +from avalanche.evaluation.functional import forgetting +from avalanche.evaluation.metrics import Accuracy +from avalanche.evaluation.plot_utils import plot_metric_matrix +from avalanche.models import SimpleMLP +from avalanche.training import ReservoirSamplingBuffer +from avalanche.training.losses import MaskedCrossEntropy + + +def train_experience(agent_state, exp, epochs=10): + # defining the training for a single exp. instead of the whole stream + # should be better because this way training does not need to care about eval loop. + # agent_state takes the role of the strategy object + + # this is the usual before_exp + # updatable_objs = [agent_state.replay, agent_state.loss, agent_state.model] + # [uo.pre_update(exp, agent_state) for uo in updatable_objs] + agent_state.loss.before_training_exp(SimpleNamespace(experience=exp)) + # agent_state.model.recursive_adaptation(exp) + + for ep in range(epochs): + agent_state.model.train() + if len(agent_state.replay.buffer) > 0: + dl = ReplayDataLoader(exp.dataset, agent_state.replay.buffer, + batch_size=32, shuffle=True) + else: + dl = DataLoader(exp.dataset, batch_size=32, shuffle=True) + + for x, y, _ in dl: + x, y = x.cuda(), y.cuda() + agent_state.opt.zero_grad() + yp = agent_state.model(x) + l = agent_state.loss(yp, y) + # l += agent_state.reg_loss() + l.backward() + agent_state.opt.step() + + # this is the usual after_exp + # updatable_objs = [agent_state.replay, agent_state.loss, agent_state.model] + # [uo.post_update(exp, agent_state) for uo in updatable_objs] + agent_state.replay.update(SimpleNamespace(experience=exp)) + agent_state.scheduler.step() + return agent_state + + +@torch.no_grad() +def my_eval(model, stream, metrics, **extra_args): + # eval also becomes simpler. Notice how in Avalanche it's harder to check whether + # we are evaluating a single exp. or the whole stream. + # Now we evaluate each stream with a separate function call + + [uo.reset() for uo in metrics] + res = {uo.__class__.__name__: [] for uo in metrics} + for exp in stream: + dl = DataLoader(exp.dataset, batch_size=512, num_workers=8) + for x, y, _ in dl: + x, y = x.cuda(), y.cuda() + yp = model(x) + [uo.update(yp, y) for uo in metrics] + [res[uo.__class__.__name__].append(uo.result()) for uo in metrics] + return res + + +if __name__ == '__main__': + bm = SplitMNIST(n_experiences=5) + train_stream, test_stream = bm.train_stream, bm.test_stream + ocl_stream = split_online_stream(train_stream, experience_size=256, seed=1234) + ocl_stream = with_classes_timeline(ocl_stream) + + # agent state collects all the objects that are needed during training + # many of these objects will have some state that is updated during training. + # The training function returns the updated agent state at each step. + agent_state = SimpleNamespace() + agent_state.replay = ReservoirSamplingBuffer(max_size=200) + agent_state.loss = MaskedCrossEntropy() + + # model + agent_state.model = SimpleMLP(num_classes=10).cuda() + + # optim and scheduler + agent_state.opt = SGD(agent_state.model.parameters(), lr=0.001) + agent_state.scheduler = ExponentialLR(agent_state.opt, gamma=0.999) + + print("Start training...") + metrics = [Accuracy()] + mc_test = MetricCollector(test_stream) + for exp in ocl_stream: + print(exp.classes_in_this_experience, end=' ') + agent_state = train_experience(agent_state, exp, epochs=2) + if exp.logging().is_last_subexp: + print() + res = my_eval(agent_state.model, test_stream, metrics) + print("EVAL: ", res) + mc_test.update(res) + + acc_timeline = mc_test.get("Accuracy", exp_reduce="sample_mean") + print(mc_test.get("Accuracy", exp_reduce=None)) + print(acc_timeline) + + acc_matrix = mc_test.get("Accuracy", exp_reduce=None) + fig = plot_metric_matrix(acc_matrix, title="Accuracy - Train") + fig.savefig("accuracy.png") + + fm = forgetting(acc_matrix) + fig = plot_metric_matrix(fm, title="Forgetting - Train") + fig.savefig("forgetting.png") diff --git a/tests/evaluation/test_functional.py b/tests/evaluation/test_functional.py new file mode 100644 index 000000000..650fe1065 --- /dev/null +++ b/tests/evaluation/test_functional.py @@ -0,0 +1,45 @@ +""" Metrics Tests""" + +import unittest + +import torch + +from avalanche.evaluation.functional import forgetting +import numpy as np + + +class FunctionalMetricTests(unittest.TestCase): + def test_diag_forgetting(self): + vals = [ + [1.0, 0.0, 0.0], + [0.5, 0.8, 1.0], + [0.2, 1.0, 0.5] + ] + target = [ + [0.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.8, -0.2, 0.0] + ] + fm = forgetting(vals) + np.testing.assert_array_almost_equal(fm, target) + + def test_non_diag_forgetting(self): + """non-diagonal matrix requires explicit boundary indices""" + vals = [ + [1.0, 0.0, 0.0], + [0.5, 0.8, 1.0], + [0.5, 0.5, 1.0], + [0.2, 1.0, 0.5] + ] + target = [ + [0.0, 0.0, 0.0], + [0.5, 0.0, 0.0], + [0.5, 0.3, 0.0], + [0.8, -0.2, 0.0] + ] + fm = forgetting(vals, boundary_indices=[0, 1, 3]) + np.testing.assert_array_almost_equal(fm, target) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/evaluation/test_plots.py b/tests/evaluation/test_plots.py new file mode 100644 index 000000000..65123d901 --- /dev/null +++ b/tests/evaluation/test_plots.py @@ -0,0 +1,23 @@ +""" Metrics Tests""" + +import unittest + +from avalanche.evaluation.plot_utils import plot_metric_matrix + + +class PlotMetricTests(unittest.TestCase): + def test_basic(self): + """just checking that the code runs without errors.""" + mvals = [ + [0.76370887, 0.43307481, 0.30893094, 0.25093633, 0.2075], + [0.74177468, 0.82286715, 0.55876494, 0.42234707, 0.341], + [0.57178465, 0.65048787, 0.7563081, 0.56941323, 0.4562], + [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], + [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], + [0.4995015, 0.58819114, 0.57320717, 0.62322097, 0.6925] + ] + plot_metric_matrix(mvals, title="Accuracy - Train") + + +if __name__ == "__main__": + unittest.main() From 74b4f67313a1b41b5bf1aecd487672edf3682162 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 25 Mar 2024 18:22:52 +0100 Subject: [PATCH 03/18] ADD time_reduce and collector tests --- avalanche/core.py | 25 ++++++++++++- avalanche/evaluation/collector.py | 21 ++++++++--- avalanche/evaluation/functional.py | 1 + examples/updatable_objects.py | 9 +++-- tests/evaluation/test_collector.py | 56 ++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 9 deletions(-) create mode 100644 tests/evaluation/test_collector.py diff --git a/avalanche/core.py b/avalanche/core.py index 0d2364359..018aa204a 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -1,13 +1,36 @@ from abc import ABC -from typing import Any, TypeVar, Generic +from typing import Any, TypeVar, Generic, Protocol, runtime_checkable from typing import TYPE_CHECKING +from avalanche.benchmarks import CLExperience + if TYPE_CHECKING: from avalanche.training.templates.base import BaseTemplate Template = TypeVar("Template", bound="BaseTemplate") +class Agent(): + def __init__(self): + self.updatable_objects = [] + self.train_steps = 0 + + def __setattribute__(self): + pass + + def pre_update(self): + pass + + +@runtime_checkable +class Updatable(Protocol): + def pre_update(self, agent: Agent, exp: CLExperience): + pass + + def post_udpate(self, agent: Agent, exp: CLExperience): + pass + + class BasePlugin(Generic[Template], ABC): """ABC for BaseTemplate plugins. diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 09fdb8162..29c29e52b 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -29,17 +29,28 @@ def update(self, res): self.metrics_res[k].append(v) def get(self, name, time_reduce=None, exp_reduce=None): - assert time_reduce in {None} # 'last', 'average' + assert time_reduce in {None, "last", "mean"} assert exp_reduce in {None, "sample_mean", "experience_mean"} assert name in self.metrics_res mvals = np.array(self.metrics_res[name]) if exp_reduce is None: - return mvals + pass # nothing to do here elif exp_reduce == "sample_mean": mvals = mvals * self._coeffs[None, :] # weight by num.samples - return mvals.sum(axis=1) # weighted avg across exp. - elif exp_reduce == "experience_means": - return mvals.mean(axis=1) # avg across exp. + mvals = mvals.sum(axis=1) # weighted avg across exp. + elif exp_reduce == "experience_mean": + mvals = mvals.mean(axis=1) # avg across exp. else: raise ValueError("BUG. It should never get here.") + + if time_reduce is None: + pass # nothing to do here + elif time_reduce == "last": + mvals = mvals[-1] # last timestep + elif time_reduce == "mean": + mvals = mvals.mean(axis=0) # avg. over time + else: + raise ValueError("BUG. It should never get here.") + + return mvals \ No newline at end of file diff --git a/avalanche/evaluation/functional.py b/avalanche/evaluation/functional.py index fbe24b5d9..f591b4a86 100644 --- a/avalanche/evaluation/functional.py +++ b/avalanche/evaluation/functional.py @@ -30,4 +30,5 @@ def forgetting(accuracy_matrix, boundary_indices=None): forgetting_matrix[:t_task+1, k] = 0 # then, it's acc_first - acc[t, k] forgetting_matrix[t_task+1:, k] = acc_first - accuracy_matrix[t_task+1:, k] + return forgetting_matrix diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index e314a52cd..4f2fd5c74 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -28,14 +28,16 @@ def train_experience(agent_state, exp, epochs=10): # [uo.pre_update(exp, agent_state) for uo in updatable_objs] agent_state.loss.before_training_exp(SimpleNamespace(experience=exp)) # agent_state.model.recursive_adaptation(exp) + # agent_state.pre_update(agent_state, exp) + data = exp.dataset.train() for ep in range(epochs): agent_state.model.train() if len(agent_state.replay.buffer) > 0: - dl = ReplayDataLoader(exp.dataset, agent_state.replay.buffer, + dl = ReplayDataLoader(data, agent_state.replay.buffer, batch_size=32, shuffle=True) else: - dl = DataLoader(exp.dataset, batch_size=32, shuffle=True) + dl = DataLoader(data, batch_size=32, shuffle=True) for x, y, _ in dl: x, y = x.cuda(), y.cuda() @@ -51,6 +53,7 @@ def train_experience(agent_state, exp, epochs=10): # [uo.post_update(exp, agent_state) for uo in updatable_objs] agent_state.replay.update(SimpleNamespace(experience=exp)) agent_state.scheduler.step() + # agent_state.post_update(agent_state, exp) return agent_state @@ -60,9 +63,9 @@ def my_eval(model, stream, metrics, **extra_args): # we are evaluating a single exp. or the whole stream. # Now we evaluate each stream with a separate function call - [uo.reset() for uo in metrics] res = {uo.__class__.__name__: [] for uo in metrics} for exp in stream: + [uo.reset() for uo in metrics] dl = DataLoader(exp.dataset, batch_size=512, num_workers=8) for x, y, _ in dl: x, y = x.cuda(), y.cuda() diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py new file mode 100644 index 000000000..456853395 --- /dev/null +++ b/tests/evaluation/test_collector.py @@ -0,0 +1,56 @@ +""" Metrics Tests""" + +import unittest +from types import SimpleNamespace + +import torch + +from avalanche.evaluation.collector import MetricCollector +from avalanche.evaluation.functional import forgetting +import numpy as np + + +class MetricCollectorTests(unittest.TestCase): + def test_time_reduce(self): + # we just need an object with __len__ + fake_stream = [ + SimpleNamespace(dataset=list(range(1))), + SimpleNamespace(dataset=list(range(2))), + ] + # m = MockMetric([0, 1, 2, 3, 4, 5]) + mc = MetricCollector(fake_stream) + stream_of_mvals = [ + [1, 3], + [5, 7], + [11, 13] + ] + for vvv in stream_of_mvals: + mc.update({"FakeMetric": vvv}) + + # time_reduce = None + v = mc.get("FakeMetric", time_reduce=None, exp_reduce=None) + np.testing.assert_array_almost_equal(v, stream_of_mvals) + v = mc.get("FakeMetric", time_reduce=None, exp_reduce="sample_mean") + np.testing.assert_array_almost_equal(v, [(1+3*2)/3, (5+7*2)/3, (11+13*2)/3]) + v = mc.get("FakeMetric", time_reduce=None, exp_reduce="experience_mean") + np.testing.assert_array_almost_equal(v, [(1+3)/2, (5+7)/2, (11+13)/2]) + + # time = "last" + v = mc.get("FakeMetric", time_reduce="last", exp_reduce=None) + np.testing.assert_array_almost_equal(v, [11, 13]) + v = mc.get("FakeMetric", time_reduce="last", exp_reduce="sample_mean") + self.assertAlmostEqual(v, (11+13*2)/3) + v = mc.get("FakeMetric", time_reduce="last", exp_reduce="experience_mean") + self.assertAlmostEqual(v, (11+13)/2) + + # time_reduce = "mean" + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce=None) + np.testing.assert_array_almost_equal(v, [(1+5+11)/3, (3+7+13)/3]) + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="sample_mean") + self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3*2) / 3) + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="experience_mean") + self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3) / 2) + + +if __name__ == "__main__": + unittest.main() From f4e3bbc9d2f54a8abcf7f55d14161629186605a7 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 25 Mar 2024 18:40:36 +0100 Subject: [PATCH 04/18] add MetricCollector serialization methods --- avalanche/core.py | 41 +++++++++++++++++++++++-------- avalanche/evaluation/collector.py | 31 ++++++++++++++++++++++- examples/updatable_objects.py | 6 ++++- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/avalanche/core.py b/avalanche/core.py index 018aa204a..bf6d39def 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -1,3 +1,4 @@ +# TODO: doc from abc import ABC from typing import Any, TypeVar, Generic, Protocol, runtime_checkable from typing import TYPE_CHECKING @@ -10,16 +11,36 @@ Template = TypeVar("Template", bound="BaseTemplate") -class Agent(): - def __init__(self): - self.updatable_objects = [] - self.train_steps = 0 - - def __setattribute__(self): - pass - - def pre_update(self): - pass +class Agent: + def __init__(self, verbose=False): + # TODO: doc + # TODO: test pre_update call + # TODO: test post_update call + self._updatable_objects = [] + self.verbose = verbose + + def __setattr__(self, name, value): + super().__setattr__(name, value) + if hasattr(value, 'pre_update') or hasattr(value, 'post_update'): + self._updatable_objects.append(value) + if self.verbose: + print("Added updatable object ", value) + + def pre_update(self, exp): + # TODO: doc + for uo in self._updatable_objects: + if hasattr(uo, 'pre_update'): + uo.pre_update(self, exp) + if self.verbose: + print("preupdate ", uo) + + def post_update(self, exp): + # TODO: doc + for uo in self._updatable_objects: + if hasattr(uo, 'post_update'): + uo.post_update(self, exp) + if self.verbose: + print("post_update ", uo) @runtime_checkable diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 29c29e52b..4fdbd8d95 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -1,3 +1,6 @@ +# TODO: doc +import json + import numpy as np from avalanche._annotations import experimental @@ -5,7 +8,9 @@ @experimental() class MetricCollector: + # TODO: doc def __init__(self, stream): + # TODO: doc self.metrics_res = None self._stream_len = 0 @@ -17,6 +22,7 @@ def __init__(self, stream): self._coeffs = self._coeffs / self._coeffs.sum() def update(self, res): + # TODO: doc if self.metrics_res is None: self.metrics_res = {} for k, v in res.items(): @@ -29,6 +35,7 @@ def update(self, res): self.metrics_res[k].append(v) def get(self, name, time_reduce=None, exp_reduce=None): + # TODO: doc assert time_reduce in {None, "last", "mean"} assert exp_reduce in {None, "sample_mean", "experience_mean"} assert name in self.metrics_res @@ -53,4 +60,26 @@ def get(self, name, time_reduce=None, exp_reduce=None): else: raise ValueError("BUG. It should never get here.") - return mvals \ No newline at end of file + return mvals + + def get_dict(self): + # TODO: doc + # TODO: test + return self.metrics_res + + def load_dict(self, d): + # TODO: doc + # TODO: test + self.metrics_res = d + + def load_json(self, fname): + # TODO: doc + # TODO: test + with open(fname, 'w') as f: + self.metrics_res = json.load(f) + + def to_json(self, fname): + # TODO: doc + # TODO: test + with open(fname, 'w') as f: + json.dump(obj=self.metrics_res, fp=f) diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 4f2fd5c74..36460d04f 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -8,6 +8,7 @@ from avalanche.benchmarks import SplitMNIST, with_classes_timeline from avalanche.benchmarks.scenarios import split_online_stream from avalanche.benchmarks.utils.data_loader import ReplayDataLoader +from avalanche.core import Agent from avalanche.evaluation.collector import MetricCollector from avalanche.evaluation.functional import forgetting @@ -76,6 +77,9 @@ def my_eval(model, stream, metrics, **extra_args): if __name__ == '__main__': + # TODO: add checkpointing + # TODO: add dynamic optimizers + bm = SplitMNIST(n_experiences=5) train_stream, test_stream = bm.train_stream, bm.test_stream ocl_stream = split_online_stream(train_stream, experience_size=256, seed=1234) @@ -84,7 +88,7 @@ def my_eval(model, stream, metrics, **extra_args): # agent state collects all the objects that are needed during training # many of these objects will have some state that is updated during training. # The training function returns the updated agent state at each step. - agent_state = SimpleNamespace() + agent_state = Agent() agent_state.replay = ReservoirSamplingBuffer(max_size=200) agent_state.loss = MaskedCrossEntropy() From 1dd8ebb4abbfd69fb0d20bcd086a4fdb12bd6ea6 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 25 Mar 2024 19:00:53 +0100 Subject: [PATCH 05/18] adaptable losses --- avalanche/core.py | 24 ++++++++++---------- avalanche/evaluation/collector.py | 1 + avalanche/training/plugins/lwf.py | 2 +- avalanche/training/regularization.py | 32 ++++++++++++++++++++++++++- examples/updatable_objects.py | 5 +++-- tests/training/test_regularization.py | 5 +++-- 6 files changed, 51 insertions(+), 18 deletions(-) diff --git a/avalanche/core.py b/avalanche/core.py index bf6d39def..893eb7c48 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -21,34 +21,34 @@ def __init__(self, verbose=False): def __setattr__(self, name, value): super().__setattr__(name, value) - if hasattr(value, 'pre_update') or hasattr(value, 'post_update'): + if hasattr(value, 'pre_adapt') or hasattr(value, 'post_adapt'): self._updatable_objects.append(value) if self.verbose: print("Added updatable object ", value) - def pre_update(self, exp): + def pre_adapt(self, exp): # TODO: doc for uo in self._updatable_objects: - if hasattr(uo, 'pre_update'): - uo.pre_update(self, exp) + if hasattr(uo, 'pre_adapt'): + uo.pre_adapt(self, exp) if self.verbose: - print("preupdate ", uo) + print("pre_adapt ", uo) - def post_update(self, exp): + def post_adapt(self, exp): # TODO: doc for uo in self._updatable_objects: - if hasattr(uo, 'post_update'): - uo.post_update(self, exp) + if hasattr(uo, 'post_adapt'): + uo.post_adapt(self, exp) if self.verbose: - print("post_update ", uo) + print("post_adapt ", uo) @runtime_checkable -class Updatable(Protocol): - def pre_update(self, agent: Agent, exp: CLExperience): +class Adaptable(Protocol): + def pre_adapt(self, agent: Agent, exp: CLExperience): pass - def post_udpate(self, agent: Agent, exp: CLExperience): + def post_adapt(self, agent: Agent, exp: CLExperience): pass diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 4fdbd8d95..f14d05ab4 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -11,6 +11,7 @@ class MetricCollector: # TODO: doc def __init__(self, stream): # TODO: doc + # TODO: add support for different streams for each metric self.metrics_res = None self._stream_len = 0 diff --git a/avalanche/training/plugins/lwf.py b/avalanche/training/plugins/lwf.py index 28ae33ecd..4ed23d141 100644 --- a/avalanche/training/plugins/lwf.py +++ b/avalanche/training/plugins/lwf.py @@ -31,4 +31,4 @@ def after_training_exp(self, strategy, **kwargs): Save a copy of the model after each experience and update self.prev_classes to include the newly learned classes. """ - self.lwf.update(strategy.experience, strategy.model) + self.lwf.post_adapt(strategy, strategy.experience) diff --git a/avalanche/training/regularization.py b/avalanche/training/regularization.py index 277dc1ee3..dc4ab310f 100644 --- a/avalanche/training/regularization.py +++ b/avalanche/training/regularization.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from avalanche.models import MultiTaskModule, avalanche_forward +from avalanche._annotations import deprecated def stable_softmax(x): @@ -39,9 +40,16 @@ class RegularizationMethod: of an experience. """ + @deprecated(0.7, "please switch to pre_update and post_update methods.") def update(self, *args, **kwargs): raise NotImplementedError() + def pre_adapt(self, agent, exp): + pass # implementation may be empty if adapt is not needed + + def post_adapt(self, agent, exp): + pass # implementation may be empty if adapt is not needed + def __call__(self, *args, **kwargs): raise NotImplementedError() @@ -133,6 +141,29 @@ def __call__(self, mb_x, mb_pred, model): ) return alpha * self._lwf_penalty(mb_pred, mb_x, model) + def post_adapt(self, agent, exp): + """Save a copy of the model after each experience and + update self.prev_classes to include the newly learned classes. + + :param agent: agent state + :param exp: current experience + """ + self.expcount += 1 + self.prev_model = copy.deepcopy(agent.model) + task_ids = exp.dataset.targets_task_labels.uniques + + for task_id in task_ids: + task_data = exp.dataset.task_set[task_id] + pc = set(task_data.targets.uniques) + + if task_id not in self.prev_classes_by_task: + self.prev_classes_by_task[task_id] = pc + else: + self.prev_classes_by_task[task_id] = self.prev_classes_by_task[ + task_id + ].union(pc) + + @deprecated(0.7, "switch to post_udpate method") def update(self, experience, model): """Save a copy of the model after each experience and update self.prev_classes to include the newly learned classes. @@ -140,7 +171,6 @@ def update(self, experience, model): :param experience: current experience :param model: current model """ - self.expcount += 1 self.prev_model = copy.deepcopy(model) task_ids = experience.dataset.targets_task_labels.uniques diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 36460d04f..a33b17d42 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -15,7 +15,7 @@ from avalanche.evaluation.metrics import Accuracy from avalanche.evaluation.plot_utils import plot_metric_matrix from avalanche.models import SimpleMLP -from avalanche.training import ReservoirSamplingBuffer +from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy @@ -45,7 +45,7 @@ def train_experience(agent_state, exp, epochs=10): agent_state.opt.zero_grad() yp = agent_state.model(x) l = agent_state.loss(yp, y) - # l += agent_state.reg_loss() + l += agent_state.reg_loss(x, yp, agent_state.model) l.backward() agent_state.opt.step() @@ -91,6 +91,7 @@ def my_eval(model, stream, metrics, **extra_args): agent_state = Agent() agent_state.replay = ReservoirSamplingBuffer(max_size=200) agent_state.loss = MaskedCrossEntropy() + agent_state.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) # model agent_state.model = SimpleMLP(num_classes=10).cuda() diff --git a/tests/training/test_regularization.py b/tests/training/test_regularization.py index f407b7b2a..92a56ace9 100644 --- a/tests/training/test_regularization.py +++ b/tests/training/test_regularization.py @@ -1,4 +1,5 @@ import unittest +from types import SimpleNamespace import torch from torch.utils.data import DataLoader @@ -33,7 +34,7 @@ def test_lwf(self): assert loss == 0 else: assert loss > 0.0 - lwf.update(exp, teacher) + lwf.post_adapt(SimpleNamespace(model=teacher), exp) lwf = LearningWithoutForgetting() teacher = MTSimpleMLP(input_size=6) @@ -60,7 +61,7 @@ def test_lwf(self): assert torch.norm(weight.grad) > 0 model.zero_grad() - lwf.update(exp, teacher) + lwf.post_adapt(SimpleNamespace(model=teacher), exp) if __name__ == "__main__": From 143b3b0c1b63d1831ae42c04227a1feffd6cb484 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 25 Mar 2024 19:18:23 +0100 Subject: [PATCH 06/18] updatable buffers --- avalanche/training/storage_policy.py | 64 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/avalanche/training/storage_policy.py b/avalanche/training/storage_policy.py index c0df18c59..cadb176bc 100644 --- a/avalanche/training/storage_policy.py +++ b/avalanche/training/storage_policy.py @@ -24,6 +24,7 @@ ) from avalanche.models import FeatureExtractorBackbone from ..benchmarks.utils.utils import concat_datasets +from avalanche._annotations import deprecated if TYPE_CHECKING: from .templates import SupervisedTemplate, BaseSGDTemplate @@ -54,7 +55,7 @@ def buffer(self) -> AvalancheDataset: def buffer(self, new_buffer: AvalancheDataset): self._buffer = new_buffer - @abstractmethod + @deprecated(0.7, "switch to pre_adapt and post_adapt") def update(self, strategy: "SupervisedTemplate", **kwargs): """Update `self.buffer` using the `strategy` state. @@ -62,6 +63,17 @@ def update(self, strategy: "SupervisedTemplate", **kwargs): :param kwargs: :return: """ + # this should work until we deprecate self.update + self.post_adapt(strategy, strategy.experience) + + @abstractmethod + def post_adapt(self, agent_state, exp): + """Update `self.buffer` using the agent state and current experience. + + :param agent_state: + :param exp: + :return: + """ ... @abstractmethod @@ -92,11 +104,9 @@ def __init__(self, max_size: int): # INVARIANT: _buffer_weights is always sorted. self._buffer_weights = torch.zeros(0) - def update(self, strategy: "SupervisedTemplate", **kwargs): + def post_adapt(self, agent, exp): """Update buffer.""" - experience = strategy.experience - assert experience is not None - self.update_from_dataset(experience.dataset) + self.update_from_dataset(exp.dataset) def update_from_dataset(self, new_data: AvalancheDataset): """Update the buffer using the given dataset. @@ -193,16 +203,6 @@ def buffer(self, new_buffer): "You should modify `self.buffer_groups instead." ) - @abstractmethod - def update(self, strategy: "SupervisedTemplate", **kwargs): - """Update `self.buffer_groups` using the `strategy` state. - - :param strategy: - :param kwargs: - :return: - """ - ... - def resize(self, strategy, new_size): """Update the maximum size of the buffers.""" self.max_size = new_size @@ -229,19 +229,19 @@ def __init__(self, max_size: int, adaptive_size: bool = True, num_experiences=No of experiences to divide capacity over. """ super().__init__(max_size, adaptive_size, num_experiences) + self._num_exps = 0 - def update(self, strategy: "SupervisedTemplate", **kwargs): - assert strategy.experience is not None - new_data = strategy.experience.dataset - num_exps = strategy.clock.train_exp_counter + 1 - lens = self.get_group_lengths(num_exps) + def post_adapt(self, agent, exp): + self._num_exps += 1 + new_data = exp.dataset + lens = self.get_group_lengths(self._num_exps) new_buffer = ReservoirSamplingBuffer(lens[-1]) new_buffer.update_from_dataset(new_data) - self.buffer_groups[num_exps - 1] = new_buffer + self.buffer_groups[self._num_exps - 1] = new_buffer for ll, b in zip(lens, self.buffer_groups.values()): - b.resize(strategy, ll) + b.resize(agent, ll) class ClassBalancedBuffer(BalancedExemplarsBuffer[ReservoirSamplingBuffer]): @@ -279,10 +279,9 @@ def __init__( self.total_num_classes = total_num_classes self.seen_classes: Set[int] = set() - def update(self, strategy: "BaseSGDTemplate", **kwargs): + def post_adapt(self, agent, exp): """Update buffer.""" - assert strategy.experience is not None - self.update_from_dataset(strategy.experience.dataset, strategy) + self.update_from_dataset(exp.dataset, agent) def update_from_dataset( self, new_data: AvalancheDataset, strategy: Optional["BaseSGDTemplate"] = None @@ -361,10 +360,9 @@ def __init__( self.seen_groups: Set[int] = set() self._curr_strategy = None - def update(self, strategy: "SupervisedTemplate", **kwargs): - assert strategy.experience is not None - new_data: AvalancheDataset = strategy.experience.dataset - new_groups = self._make_groups(strategy, new_data) + def post_adapt(self, agent, exp): + new_data: AvalancheDataset = exp.dataset + new_groups = self._make_groups(agent, new_data) self.seen_groups.update(new_groups.keys()) # associate lengths to classes @@ -378,16 +376,16 @@ def update(self, strategy: "SupervisedTemplate", **kwargs): ll = group_to_len[group_id] if group_id in self.buffer_groups: old_buffer_g = self.buffer_groups[group_id] - old_buffer_g.update_from_dataset(strategy, new_data_g) - old_buffer_g.resize(strategy, ll) + old_buffer_g.update_from_dataset(agent, new_data_g) + old_buffer_g.resize(agent, ll) else: new_buffer = _ParametricSingleBuffer(ll, self.selection_strategy) - new_buffer.update_from_dataset(strategy, new_data_g) + new_buffer.update_from_dataset(agent, new_data_g) self.buffer_groups[group_id] = new_buffer # resize buffers for group_id, class_buf in self.buffer_groups.items(): - self.buffer_groups[group_id].resize(strategy, group_to_len[group_id]) + self.buffer_groups[group_id].resize(agent, group_to_len[group_id]) def _make_groups( self, strategy, data: AvalancheDataset From 006e8773a4e9e99300ad92db64256841512f064b Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 25 Mar 2024 19:27:29 +0100 Subject: [PATCH 07/18] updatable modules --- avalanche/models/dynamic_modules.py | 4 ++-- avalanche/models/dynamic_optimizers.py | 12 ++++++++++++ docs/gitbook/from-zero-to-hero-tutorial/02_models.md | 3 +-- examples/updatable_objects.py | 9 ++++++--- tests/models/test_models.py | 6 +++--- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index 436c75bee..b2466dff9 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -60,7 +60,7 @@ class DynamicModule(Module): expanded to allow architectural modifications (multi-head classifiers, progressive networks, ...). - Compared to pytoch Modules, they provide an additional method, + Compared to pytorch Modules, they provide an additional method, `model_adaptation`, which adapts the model given the current experience. """ @@ -73,7 +73,7 @@ def __init__(self, auto_adapt=True): super().__init__() self._auto_adapt = auto_adapt - def recursive_adaptation(self, experience): + def pre_adapt(self, experience): """ Calls self.adaptation recursively accross the hierarchy of pytorch module childrens diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index e48241ee9..4aa7af3d7 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -17,6 +17,18 @@ from collections import defaultdict +class DynamicOptimizer: + def __init__(self, opt): + self._opt = opt + + def pre_adapt(self): + # TODO: implement optimizer adaptation after albin PR is merged + # TODO: support scheduling and further wrapping... + # TODO: test + # TODO: doc + pass + + def compare_keys(old_dict, new_dict): not_in_new = list(set(old_dict.keys()) - set(new_dict.keys())) in_both = list(set(old_dict.keys()) & set(new_dict.keys())) diff --git a/docs/gitbook/from-zero-to-hero-tutorial/02_models.md b/docs/gitbook/from-zero-to-hero-tutorial/02_models.md index 0c566f83d..ba6ca49c2 100644 --- a/docs/gitbook/from-zero-to-hero-tutorial/02_models.md +++ b/docs/gitbook/from-zero-to-hero-tutorial/02_models.md @@ -109,7 +109,6 @@ print(mt_model) ### Nested Dynamic Modules Whenever one or more dynamic modules are nested one inside the other, you must call the `recursive_adaptation` method, and if they are nested inside a normal pytorch module (non dynamic), you can call the `avalanche_model_adaptation` function. Avalanche strategies will by default adapt the models before training on each experience by calling `avalanche_model_adaptation` - ```python benchmark = SplitMNIST(5, shuffle=False, class_ids_from_zero_in_each_exp=True, return_task_id=True) @@ -118,7 +117,7 @@ mt_model = as_multitask(model, 'classifier') print(mt_model) for exp in benchmark.train_stream: - mt_model.recursive_adaptation(exp) + mt_model.pre_adapt(exp) print(mt_model) ``` diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index a33b17d42..efb1b414d 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -14,7 +14,8 @@ from avalanche.evaluation.functional import forgetting from avalanche.evaluation.metrics import Accuracy from avalanche.evaluation.plot_utils import plot_metric_matrix -from avalanche.models import SimpleMLP +from avalanche.models import SimpleMLP, IncrementalClassifier +from avalanche.models.dynamic_optimizers import DynamicOptimizer from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy @@ -94,10 +95,12 @@ def my_eval(model, stream, metrics, **extra_args): agent_state.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) # model - agent_state.model = SimpleMLP(num_classes=10).cuda() + agent_state.model = SimpleMLP().cuda() + agent_state.model.classifier = IncrementalClassifier(in_features=512) # optim and scheduler - agent_state.opt = SGD(agent_state.model.parameters(), lr=0.001) + opt = SGD(agent_state.model.parameters(), lr=0.001) + agent_state.opt = DynamicOptimizer(opt) # TODO: maybe we can do this with some hooks??? agent_state.scheduler = ExponentialLR(agent_state.opt, gamma=0.999) print("Start training...") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b7b2793bf..904186f8a 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -546,10 +546,10 @@ def test_recursive_adaptation(self): sizes = {} for t, exp in enumerate(benchmark.train_stream): sizes[t] = np.max(exp.classes_in_this_experience) + 1 - model1.recursive_adaptation(exp) + model1.pre_adapt(exp) # Second adaptation should not change anything for t, exp in enumerate(benchmark.train_stream): - model1.recursive_adaptation(exp) + model1.pre_adapt(exp) for t, s in sizes.items(): self.assertEqual(s, model1.classifiers[str(t)].classifier.out_features) @@ -562,7 +562,7 @@ def test_recursive_loop(self): model2.layer2 = model1 benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True) - model1.recursive_adaptation(benchmark.train_stream[0]) + model1.pre_adapt(benchmark.train_stream[0]) def test_multi_head_classifier_masking(self): benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True) From 480207050c0bea4e8fd392e7b24b3e0f78d8483c Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Tue, 26 Mar 2024 10:16:49 +0100 Subject: [PATCH 08/18] add multi-stream support to MetricCollector --- avalanche/_annotations.py | 4 +- avalanche/core.py | 22 +++++++++ avalanche/evaluation/collector.py | 66 ++++++++++++++++++-------- avalanche/models/dynamic_optimizers.py | 1 + avalanche/training/losses.py | 13 ++++- examples/updatable_objects.py | 63 +++++++++++++----------- tests/training/test_losses.py | 4 +- 7 files changed, 117 insertions(+), 56 deletions(-) diff --git a/avalanche/_annotations.py b/avalanche/_annotations.py index 1820ae811..b4b2787b4 100644 --- a/avalanche/_annotations.py +++ b/avalanche/_annotations.py @@ -50,6 +50,7 @@ def wrapper(*args, **kwargs): return decorator +# TODO: show deprecation warning only once def deprecated(version: float, reason: str): """Decorator to mark functions as deprecated. @@ -81,13 +82,12 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.simplefilter("always", DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) warnings.warn( msg.format(name=func.__name__, version=version, reason=reason), category=DeprecationWarning, stacklevel=2, ) - warnings.simplefilter("default", DeprecationWarning) return func(*args, **kwargs) return wrapper diff --git a/avalanche/core.py b/avalanche/core.py index 893eb7c48..ff31809d7 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -18,6 +18,8 @@ def __init__(self, verbose=False): # TODO: test post_update call self._updatable_objects = [] self.verbose = verbose + self._pre_hooks = [] + self._post_hooks = [] def __setattr__(self, name, value): super().__setattr__(name, value) @@ -33,6 +35,10 @@ def pre_adapt(self, exp): uo.pre_adapt(self, exp) if self.verbose: print("pre_adapt ", uo) + for foo in self._pre_hooks: + if self.verbose: + print("pre_adapt hook ", foo) + foo(self, exp) def post_adapt(self, exp): # TODO: doc @@ -41,10 +47,26 @@ def post_adapt(self, exp): uo.post_adapt(self, exp) if self.verbose: print("post_adapt ", uo) + for foo in self._post_hooks: + if self.verbose: + print("post_adapt hook ", foo) + foo(self, exp) + + def add_pre_hooks(self, foo): + # TODO: doc + # TODO: test + self._pre_hooks.append(foo) + + def add_post_hooks(self, foo): + # TODO: doc + # TODO: test + self._post_hooks.append(foo) @runtime_checkable class Adaptable(Protocol): + # TODO: doc + # TODO: test runtime-checkability def pre_adapt(self, agent: Agent, exp: CLExperience): pass diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index f14d05ab4..02960b4ce 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -9,43 +9,67 @@ @experimental() class MetricCollector: # TODO: doc - def __init__(self, stream): + def __init__(self): # TODO: doc - # TODO: add support for different streams for each metric - self.metrics_res = None + self.metrics_res = {} - self._stream_len = 0 - self._coeffs = [] + self._stream_len = {} # stream-name -> stream length + self._coeffs = {} # stream-name -> list_of_exp_length + + def _init_stream(self, stream): + # we compute the stream length and number of samples in each experience + # only the first time we encounter the experience because + # iterating over the stream may be expensive + self._stream_len[stream.name] = 0 + coeffs = [] for exp in stream: - self._stream_len += 1 - self._coeffs.append(len(exp.dataset)) - self._coeffs = np.array(self._coeffs) - self._coeffs = self._coeffs / self._coeffs.sum() + self._stream_len[stream.name] += 1 + coeffs.append(len(exp.dataset)) + coeffs = np.array(coeffs) + self._coeffs[stream.name] = coeffs / coeffs.sum() - def update(self, res): + def update(self, res, *, stream=None): # TODO: doc - if self.metrics_res is None: - self.metrics_res = {} - for k, v in res.items(): - if len(v) != self._stream_len: + # TODO: test multi groups + for k, v in res.items(): + # optional safety check on metric shape + # metrics are expected to have one value for each experience in the stream + if stream is not None: + if stream.name not in self._stream_len: + self._init_stream(stream) + if len(v) != self._stream_len[stream.name]: raise ValueError(f"Length does not correspond to stream. " - f"Found {len(v)}, expected {self._stream_len}") - self.metrics_res[k] = [v] - else: - for k, v in res.items(): + f"Found {len(v)}, expected {self._stream_len[stream.name]}") + + # update metrics dictionary + if stream is not None: + k = f"{stream.name}/{k}" + if k in self.metrics_res: self.metrics_res[k].append(v) + else: + self.metrics_res[k] = [v] - def get(self, name, time_reduce=None, exp_reduce=None): + def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): # TODO: doc assert time_reduce in {None, "last", "mean"} assert exp_reduce in {None, "sample_mean", "experience_mean"} - assert name in self.metrics_res + + if stream is not None: + name = f"{stream.name}/{name}" + if name not in self.metrics_res: + print(f"{name} metric was not found. Maybe you forgot the `stream` argument?") mvals = np.array(self.metrics_res[name]) if exp_reduce is None: pass # nothing to do here elif exp_reduce == "sample_mean": - mvals = mvals * self._coeffs[None, :] # weight by num.samples + if stream is None: + raise ValueError( + "If you want to use `exp_reduce == sample_mean` you need to provide" + "the `stream` argument to the `update` and `get` methods.") + if stream.name not in self._coeffs: + self._init_stream(stream) + mvals = mvals * self._coeffs[stream.name][None, :] # weight by num.samples mvals = mvals.sum(axis=1) # weighted avg across exp. elif exp_reduce == "experience_mean": mvals = mvals.mean(axis=1) # avg across exp. diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index 4aa7af3d7..76032675f 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -18,6 +18,7 @@ class DynamicOptimizer: + # TODO: remove, not needed? def __init__(self, opt): self._opt = opt diff --git a/avalanche/training/losses.py b/avalanche/training/losses.py index 014cb4129..460b01aa2 100644 --- a/avalanche/training/losses.py +++ b/avalanche/training/losses.py @@ -8,6 +8,7 @@ from avalanche.training.plugins import SupervisedPlugin from avalanche.training.regularization import cross_entropy_with_oh_targets +from avalanche._annotations import deprecated class ICaRLLossPlugin(SupervisedPlugin): @@ -214,12 +215,20 @@ def current_mask(self, logit_shape): elif self.mask == "all": return list(range(int(logit_shape))) - def adaptation(self, new_classes): + def _adaptation(self, new_classes): self.old_classes = self.old_classes.union(self.current_classes) self.current_classes = set(new_classes) + def pre_adapt(self, agent, exp): + self._adaptation(exp.classes_in_this_experience) + + @deprecated(0.7, "Please switch to the `pre_adapt`or `_adaptation` methods.") + def adaptation(self, new_classes): + self._adaptation(new_classes) + + @deprecated(0.7, "Please switch to the `pre_adapt` method.") def before_training_exp(self, strategy, **kwargs): - self.adaptation(strategy.experience.classes_in_this_experience) + self._adaptation(strategy.experience.classes_in_this_experience) __all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"] diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index efb1b414d..866c8ede9 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -1,3 +1,5 @@ +import setGPU + from types import SimpleNamespace import torch @@ -15,7 +17,8 @@ from avalanche.evaluation.metrics import Accuracy from avalanche.evaluation.plot_utils import plot_metric_matrix from avalanche.models import SimpleMLP, IncrementalClassifier -from avalanche.models.dynamic_optimizers import DynamicOptimizer +from avalanche.models.dynamic_modules import avalanche_model_adaptation +from avalanche.models.dynamic_optimizers import DynamicOptimizer, reset_optimizer from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy @@ -28,9 +31,8 @@ def train_experience(agent_state, exp, epochs=10): # this is the usual before_exp # updatable_objs = [agent_state.replay, agent_state.loss, agent_state.model] # [uo.pre_update(exp, agent_state) for uo in updatable_objs] - agent_state.loss.before_training_exp(SimpleNamespace(experience=exp)) # agent_state.model.recursive_adaptation(exp) - # agent_state.pre_update(agent_state, exp) + agent_state.pre_adapt(exp) data = exp.dataset.train() for ep in range(epochs): @@ -49,18 +51,12 @@ def train_experience(agent_state, exp, epochs=10): l += agent_state.reg_loss(x, yp, agent_state.model) l.backward() agent_state.opt.step() - - # this is the usual after_exp - # updatable_objs = [agent_state.replay, agent_state.loss, agent_state.model] - # [uo.post_update(exp, agent_state) for uo in updatable_objs] - agent_state.replay.update(SimpleNamespace(experience=exp)) - agent_state.scheduler.step() - # agent_state.post_update(agent_state, exp) + agent_state.post_adapt(exp) return agent_state @torch.no_grad() -def my_eval(model, stream, metrics, **extra_args): +def my_eval(model, stream, metrics): # eval also becomes simpler. Notice how in Avalanche it's harder to check whether # we are evaluating a single exp. or the whole stream. # Now we evaluate each stream with a separate function call @@ -89,37 +85,46 @@ def my_eval(model, stream, metrics, **extra_args): # agent state collects all the objects that are needed during training # many of these objects will have some state that is updated during training. # The training function returns the updated agent state at each step. - agent_state = Agent() - agent_state.replay = ReservoirSamplingBuffer(max_size=200) - agent_state.loss = MaskedCrossEntropy() - agent_state.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) + agent = Agent() + agent.replay = ReservoirSamplingBuffer(max_size=200) + agent.loss = MaskedCrossEntropy() + agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) # model - agent_state.model = SimpleMLP().cuda() - agent_state.model.classifier = IncrementalClassifier(in_features=512) + agent.model = SimpleMLP().cuda() + agent.model.classifier = IncrementalClassifier(in_features=512).cuda() + agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e)) # optim and scheduler - opt = SGD(agent_state.model.parameters(), lr=0.001) - agent_state.opt = DynamicOptimizer(opt) # TODO: maybe we can do this with some hooks??? - agent_state.scheduler = ExponentialLR(agent_state.opt, gamma=0.999) + agent.opt = SGD(agent.model.parameters(), lr=0.001) + agent.add_pre_hooks(lambda a, e: reset_optimizer(a.opt, a.model)) # TODO: update after albin PR is merged + # agent_state.opt = DynamicOptimizer(opt) # TODO: maybe we can do this with some hooks??? + agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) + agent.add_post_hooks(lambda a, e: a.scheduler.step()) print("Start training...") metrics = [Accuracy()] - mc_test = MetricCollector(test_stream) + mc = MetricCollector() for exp in ocl_stream: - print(exp.classes_in_this_experience, end=' ') - agent_state = train_experience(agent_state, exp, epochs=2) + if exp.logging().is_first_subexp: + print("training on classes ", exp.classes_in_this_experience) + agent = train_experience(agent, exp, epochs=2) if exp.logging().is_last_subexp: + res = my_eval(agent.model, train_stream, metrics) + mc.update(res, stream=train_stream) + print("RES: ", res) + res = my_eval(agent.model, test_stream, metrics) + mc.update(res, stream=test_stream) + print("EVAL RES: ", res) print() - res = my_eval(agent_state.model, test_stream, metrics) - print("EVAL: ", res) - mc_test.update(res) - acc_timeline = mc_test.get("Accuracy", exp_reduce="sample_mean") - print(mc_test.get("Accuracy", exp_reduce=None)) + print(mc.get_dict()) + + acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) + print(mc.get("Accuracy", exp_reduce=None, stream=test_stream)) print(acc_timeline) - acc_matrix = mc_test.get("Accuracy", exp_reduce=None) + acc_matrix = mc.get("Accuracy", exp_reduce=None, stream=test_stream) fig = plot_metric_matrix(acc_matrix, title="Accuracy - Train") fig.savefig("accuracy.png") diff --git a/tests/training/test_losses.py b/tests/training/test_losses.py index 4935fb88c..5a483c785 100644 --- a/tests/training/test_losses.py +++ b/tests/training/test_losses.py @@ -41,8 +41,8 @@ def test_loss(self): cross_entropy = nn.CrossEntropyLoss() criterion = MaskedCrossEntropy(mask="new") - criterion.adaptation([1, 2, 3, 4]) - criterion.adaptation([5, 6, 7]) + criterion._adaptation([1, 2, 3, 4]) + criterion._adaptation([5, 6, 7]) mb_y = torch.tensor([5, 5, 6, 7, 6]) From 3b5f9b42329a5bc656a1193bdc6c8662dd330a0a Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Tue, 26 Mar 2024 11:13:57 +0100 Subject: [PATCH 09/18] document metric collector --- avalanche/evaluation/collector.py | 66 +++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 02960b4ce..3228b7fd6 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -8,9 +8,25 @@ @experimental() class MetricCollector: - # TODO: doc + """A simple metric collector object. + + Functionlity includes the ability to store metrics over time and compute + aggregated values of them. Serialization is supported via json files. + + Example usage (pseudocode missing imports and init and other objects): + + ``` + mc = MetricCollector() + for exp in train_stream: + agent = train_experience(agent, exp) + res = my_eval(agent.model, test_stream, metrics) + mc.update(res, stream=test_stream) + acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) + ``` + + """ def __init__(self): - # TODO: doc + """Init.""" self.metrics_res = {} self._stream_len = {} # stream-name -> stream length @@ -29,7 +45,12 @@ def _init_stream(self, stream): self._coeffs[stream.name] = coeffs / coeffs.sum() def update(self, res, *, stream=None): - # TODO: doc + """Update the metrics. + + :param res: a dictionary of new metrics with items. + :param stream: optional stream. If a stream is given the full metric + name becomes `f'{stream.name}/{metric_name}'`. + """ # TODO: test multi groups for k, v in res.items(): # optional safety check on metric shape @@ -50,7 +71,23 @@ def update(self, res, *, stream=None): self.metrics_res[k] = [v] def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): - # TODO: doc + """Returns a metric value given its name and aggregation method. + + :param name: name of the metric. + :param time_reduce: Aggregation over the time dimension. One of {None, 'last', 'mean'}, where: + - None (default) does not use any aggregation + - 'last' returns the last timestep + - 'mean' averages over time + :param exp_reduce: Aggregation over the experience dimension. One of {None, 'sample_mean', 'experience_mean'} where: + - None (default) does not use any aggregation + - `sample_mean` is an average weighted by the number of samples in each experience + - `experience_mean` is an experience average. + :param stream: stream that was used to compute the metric. This is + needed to build the full metric name if the get was called with a + stream name and if `exp_reduce == sample_mean` to get the number + of samples from each experience. + :return: aggregated metric value. + """ assert time_reduce in {None, "last", "mean"} assert exp_reduce in {None, "sample_mean", "experience_mean"} @@ -88,23 +125,36 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): return mvals def get_dict(self): - # TODO: doc + """Returns metrics dictionary. + + :return: metrics dictionary + """ # TODO: test return self.metrics_res def load_dict(self, d): - # TODO: doc + """Loads a new metrics dictionary. + + :param d: metrics dictionary + """ # TODO: test self.metrics_res = d def load_json(self, fname): - # TODO: doc + """Loads a metrics dictionary from a json file. + + :param fname: file name + """ # TODO: test with open(fname, 'w') as f: self.metrics_res = json.load(f) def to_json(self, fname): - # TODO: doc + """Stores metrics dictionary as a json filename. + + :param fname: + :return: + """ # TODO: test with open(fname, 'w') as f: json.dump(obj=self.metrics_res, fp=f) From 4035093d63cba285aa2845d5b16e12294e1f1e90 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Thu, 28 Mar 2024 10:37:38 +0100 Subject: [PATCH 10/18] revert incorrect merge --- avalanche/models/dynamic_optimizers.py | 318 ++++++++++++++++++++----- 1 file changed, 264 insertions(+), 54 deletions(-) diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index e48241ee9..e5dc6108f 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -16,14 +16,220 @@ """ from collections import defaultdict +import numpy as np -def compare_keys(old_dict, new_dict): - not_in_new = list(set(old_dict.keys()) - set(new_dict.keys())) - in_both = list(set(old_dict.keys()) & set(new_dict.keys())) - not_in_old = list(set(new_dict.keys()) - set(old_dict.keys())) - return not_in_new, in_both, not_in_old +from avalanche._annotations import deprecated +colors = { + "END": "\033[0m", + 0: "\033[32m", + 1: "\033[33m", + 2: "\033[34m", + 3: "\033[35m", + 4: "\033[36m", +} +colors[None] = colors["END"] + +def _map_optimized_params(optimizer, parameters, old_params=None): + """ + Establishes a mapping between a list of named parameters and the parameters + that are in the optimizer, additionally, + returns the lists of: + + returns: + new_parameters: Names of new parameters in the provided "parameters" argument, + that are not in the old parameters + changed_parameters: Names and indexes of parameters that have changed (grown, shrink) + not_found_in_parameters: List of indexes of optimizer parameters + that are not found in the provided parameters + """ + + if old_params is None: + old_params = {} + + group_mapping = defaultdict(dict) + new_parameters = [] + + found_indexes = [] + changed_parameters = [] + for group in optimizer.param_groups: + params = group["params"] + found_indexes.append(np.zeros(len(params))) + + for n, p in parameters.items(): + gidx = None + pidx = None + + # Find param in optimizer + found = False + + if n in old_params: + search_id = id(old_params[n]) + else: + search_id = id(p) + + for group_idx, group in enumerate(optimizer.param_groups): + params = group["params"] + for param_idx, po in enumerate(params): + if id(po) == search_id: + gidx = group_idx + pidx = param_idx + found = True + # Update found indexes + assert found_indexes[group_idx][param_idx] == 0 + found_indexes[group_idx][param_idx] = 1 + break + if found: + break + + if not found: + new_parameters.append(n) + + if search_id != id(p): + if found: + changed_parameters.append((n, gidx, pidx)) + + if len(optimizer.param_groups) > 1: + group_mapping[n] = gidx + else: + group_mapping[n] = 0 + + not_found_in_parameters = [np.where(arr == 0)[0] for arr in found_indexes] + + return ( + group_mapping, + changed_parameters, + new_parameters, + not_found_in_parameters, + ) + + +def _build_tree_from_name_groups(name_groups): + root = _TreeNode("") # Root node + node_mapping = {} + + # Iterate through each string in the list + for name, group in name_groups.items(): + components = name.split(".") + current_node = root + + # Traverse the tree and construct nodes for each component + for component in components: + if component not in current_node.children: + current_node.children[component] = _TreeNode( + component, parent=current_node + ) + current_node = current_node.children[component] + + # Update the groups for the leaf node + if group is not None: + current_node.groups |= set([group]) + current_node.update_groups_upwards() # Inform parent about the groups + + # Update leaf node mapping dict + node_mapping[name] = current_node + + # This will resolve nodes without group + root.update_groups_downwards() + return root, node_mapping + + +def _print_group_information(node, prefix=""): + # Print the groups for the current node + + if len(node.groups) == 1: + pstring = ( + colors[list(node.groups)[0]] + + f"{prefix}{node.global_name()}: {node.groups}" + + colors["END"] + ) + print(pstring) + else: + print(f"{prefix}{node.global_name()}: {node.groups}") + + # Recursively print group information for children nodes + for child_name, child_node in node.children.items(): + _print_group_information(child_node, prefix + " ") + + +class _ParameterGroupStructure: + """ + Structure used for the resolution of unknown parameter groups, + stores parameters as a tree and propagates parameter groups from leaves of + the same hierarchical level + """ + + def __init__(self, name_groups, verbose=False): + # Here we rebuild the tree + self.root, self.node_mapping = _build_tree_from_name_groups(name_groups) + if verbose: + _print_group_information(self.root) + + def __getitem__(self, name): + return self.node_mapping[name] + + +class _TreeNode: + def __init__(self, name, parent=None): + self.name = name + self.children = {} + self.groups = set() # Set of groups (represented by index) this node belongs to + self.parent = parent # Reference to the parent node + if parent: + # Inform the parent about the new child node + parent.add_child(self) + + def add_child(self, child): + self.children[child.name] = child + + def update_groups_upwards(self): + if self.parent: + if self.groups != {None}: + self.parent.groups |= ( + self.groups + ) # Update parent's groups with the child's groups + self.parent.update_groups_upwards() # Propagate the group update to the parent + + def update_groups_downwards(self, new_groups=None): + # If you are a node with no groups, absorb + if len(self.groups) == 0 and new_groups is not None: + self.groups = self.groups.union(new_groups) + + # Then transmit + if len(self.groups) > 0: + for key, children in self.children.items(): + children.update_groups_downwards(self.groups) + + def global_name(self, initial_name=None): + """ + Returns global node name + """ + if initial_name is None: + initial_name = self.name + elif self.name != "": + initial_name = ".".join([self.name, initial_name]) + + if self.parent: + return self.parent.global_name(initial_name) + else: + return initial_name + + @property + def single_group(self): + if len(self.groups) == 0: + raise AttributeError( + f"Could not identify group for this node {self.global_name()}" + ) + elif len(self.groups) > 1: + raise AttributeError( + f"No unique group found for this node {self.global_name()}" + ) + else: + return list(self.groups)[0] + + +@deprecated(0.6, "update_optimizer with optimized_params=None is now used instead") def reset_optimizer(optimizer, model): """Reset the optimizer to update the list of learnable parameters. @@ -53,7 +259,14 @@ def reset_optimizer(optimizer, model): return optimized_param_id -def update_optimizer(optimizer, new_params, optimized_params, reset_state=False): +def update_optimizer( + optimizer, + new_params, + optimized_params=None, + reset_state=False, + remove_params=False, + verbose=False, +): """Update the optimizer by adding new parameters, removing removed parameters, and adding new parameters to the optimizer, for instance after model has been adapted @@ -64,72 +277,69 @@ def update_optimizer(optimizer, new_params, optimized_params, reset_state=False) :param new_params: Dict (name, param) of new parameters :param optimized_params: Dict (name, param) of - currently optimized parameters (returned by reset_optimizer) - :param reset_state: Wheter to reset the optimizer's state (i.e momentum). - Defaults to False. + currently optimized parameters + :param reset_state: Whether to reset the optimizer's state (i.e momentum). + Defaults to False. + :param remove_params: Whether to remove parameters that were in the optimizer + but are not found in new parameters. For safety reasons, + defaults to False. + :param verbose: If True, prints information about inferred + parameter groups for new params + :return: Dict (name, param) of optimized parameters """ - not_in_new, in_both, not_in_old = compare_keys(optimized_params, new_params) + ( + group_mapping, + changed_parameters, + new_parameters, + not_found_in_parameters, + ) = _map_optimized_params(optimizer, new_params, old_params=optimized_params) + # Change reference to already existing parameters # i.e growing IncrementalClassifier - for key in in_both: - old_p_hash = optimized_params[key] - new_p = new_params[key] + for name, group_idx, param_idx in changed_parameters: + group = optimizer.param_groups[group_idx] + old_p = optimized_params[name] + new_p = new_params[name] # Look for old parameter id in current optimizer - found = False - for group in optimizer.param_groups: - for i, curr_p in enumerate(group["params"]): - if id(curr_p) == id(old_p_hash): - found = True - if id(curr_p) != id(new_p): - group["params"][i] = new_p - optimized_params[key] = new_p - optimizer.state[new_p] = {} - break - if not found: - raise Exception( - f"Parameter {key} expected but " "not found in the optimizer" - ) + group["params"][param_idx] = new_p + if old_p in optimizer.state: + optimizer.state.pop(old_p) + optimizer.state[new_p] = {} # Remove parameters that are not here anymore # This should not happend in most use case - keys_to_remove = [] - for key in not_in_new: - old_p_hash = optimized_params[key] - found = False - for i, group in enumerate(optimizer.param_groups): - keys_to_remove.append([]) - for j, curr_p in enumerate(group["params"]): - if id(curr_p) == id(old_p_hash): - found = True - keys_to_remove[i].append((j, curr_p)) - optimized_params.pop(key) - break - if not found: - raise Exception( - f"Parameter {key} expected but " "not found in the optimizer" - ) - - for i, idx_list in enumerate(keys_to_remove): - for j, p in sorted(idx_list, key=lambda x: x[0], reverse=True): - del optimizer.param_groups[i]["params"][j] - if p in optimizer.state: - optimizer.state.pop(p) + if remove_params: + for group_idx, idx_list in enumerate(not_found_in_parameters): + for j in sorted(idx_list, key=lambda x: x, reverse=True): + p = optimizer.param_groups[group_idx]["params"][j] + optimizer.param_groups[group_idx]["params"].pop(j) + if p in optimizer.state: + optimizer.state.pop(p) + del p # Add newly added parameters (i.e Multitask, PNN) - # by default, add to param groups 0 - for key in not_in_old: + + param_structure = _ParameterGroupStructure(group_mapping, verbose=verbose) + + # New parameters + for key in new_parameters: new_p = new_params[key] - optimizer.param_groups[0]["params"].append(new_p) + group = param_structure[key].single_group + optimizer.param_groups[group]["params"].append(new_p) optimized_params[key] = new_p optimizer.state[new_p] = {} if reset_state: optimizer.state = defaultdict(dict) - return optimized_params + return new_params +@deprecated( + 0.6, + "parameters have to be added manually to the optimizer in an existing or a new parameter group", +) def add_new_params_to_optimizer(optimizer, new_params): """Add new parameters to the trainable parameters. @@ -138,4 +348,4 @@ def add_new_params_to_optimizer(optimizer, new_params): optimizer.add_param_group({"params": new_params}) -__all__ = ["add_new_params_to_optimizer", "reset_optimizer", "update_optimizer"] +__all__ = ["add_new_params_to_optimizer", "reset_optimizer", "update_optimizer"] \ No newline at end of file From 8a69d675c463419c8bc6a1b428c5aa4aa47c86ac Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Thu, 28 Mar 2024 12:20:21 +0100 Subject: [PATCH 11/18] update example --- avalanche/_annotations.py | 2 +- avalanche/evaluation/plot_utils.py | 14 ++-- .../training/templates/common_templates.py | 2 - examples/updatable_objects.py | 82 ++++++++++++------- 4 files changed, 60 insertions(+), 40 deletions(-) diff --git a/avalanche/_annotations.py b/avalanche/_annotations.py index b4b2787b4..e2795b116 100644 --- a/avalanche/_annotations.py +++ b/avalanche/_annotations.py @@ -82,7 +82,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - warnings.simplefilter("default", DeprecationWarning) + warnings.simplefilter("once", DeprecationWarning) warnings.warn( msg.format(name=func.__name__, version=version, reason=reason), category=DeprecationWarning, diff --git a/avalanche/evaluation/plot_utils.py b/avalanche/evaluation/plot_utils.py index 0a09ef0a0..9ee83cd62 100644 --- a/avalanche/evaluation/plot_utils.py +++ b/avalanche/evaluation/plot_utils.py @@ -8,7 +8,6 @@ # E-mail: contact@continualai.org # # Website: avalanche.continualai.org # ################################################################################ -import matplotlib.pyplot as plt import numpy as np from matplotlib import pyplot as plt @@ -37,12 +36,14 @@ def plot_metric_matrix( title, *, ax=None, + text_values=True ): """Plot a matrix of metrics (e.g. forgetting over time). :param metric_matrix: 2D accuracy matrix with shape :param title: plot title :param ax: axes to use for figure + :param text_values: (bool) whether to add the value as text in each cell. :return: a matplotlib.Figure """ @@ -53,12 +54,13 @@ def plot_metric_matrix( metric_matrix = np.array(metric_matrix).T ax.matshow(metric_matrix) - ax.set_ylabel("Eval Experience") + ax.set_ylabel("Experience") ax.set_xlabel("Time") ax.set_title(title) - for i in range(len(metric_matrix)): - for j in range(len(metric_matrix[0])): - ax.text(j, i, f"{metric_matrix[i][j]:.3f}", - ha="center", va="center", color="w") + if text_values: + for i in range(len(metric_matrix)): + for j in range(len(metric_matrix[0])): + ax.text(j, i, f"{metric_matrix[i][j]:.3f}", + ha="center", va="center", color="w") return fig diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 8405c5a6f..20ba0465a 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -80,7 +80,6 @@ class SupervisedTemplate( PLUGIN_CLASS = SupervisedPlugin - # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, *, @@ -194,7 +193,6 @@ class SupervisedMetaLearningTemplate( PLUGIN_CLASS = SupervisedPlugin - # TODO: remove default values of model and optimizer when legacy positional arguments are definitively removed def __init__( self, *, diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 866c8ede9..75e941d71 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -1,6 +1,6 @@ -import setGPU +import os -from types import SimpleNamespace +import setGPU import torch from torch.optim import SGD @@ -18,26 +18,20 @@ from avalanche.evaluation.plot_utils import plot_metric_matrix from avalanche.models import SimpleMLP, IncrementalClassifier from avalanche.models.dynamic_modules import avalanche_model_adaptation -from avalanche.models.dynamic_optimizers import DynamicOptimizer, reset_optimizer +from avalanche.models.dynamic_optimizers import update_optimizer from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy def train_experience(agent_state, exp, epochs=10): - # defining the training for a single exp. instead of the whole stream - # should be better because this way training does not need to care about eval loop. - # agent_state takes the role of the strategy object + agent_state.model.train() + data = exp.dataset.train() # avalanche datasets have train/eval modes to switch augmentations - # this is the usual before_exp - # updatable_objs = [agent_state.replay, agent_state.loss, agent_state.model] - # [uo.pre_update(exp, agent_state) for uo in updatable_objs] - # agent_state.model.recursive_adaptation(exp) - agent_state.pre_adapt(exp) - - data = exp.dataset.train() + agent_state.pre_adapt(exp) # update objects and call pre_hooks for ep in range(epochs): - agent_state.model.train() if len(agent_state.replay.buffer) > 0: + # if the replay buffer is not empty we sample from + # current data and replay buffer in parallel dl = ReplayDataLoader(data, agent_state.replay.buffer, batch_size=32, shuffle=True) else: @@ -48,10 +42,14 @@ def train_experience(agent_state, exp, epochs=10): agent_state.opt.zero_grad() yp = agent_state.model(x) l = agent_state.loss(yp, y) + + # you may have to change this if your regularization loss + # needs different arguments l += agent_state.reg_loss(x, yp, agent_state.model) + l.backward() agent_state.opt.step() - agent_state.post_adapt(exp) + agent_state.post_adapt(exp) # update objects and call post_hooks return agent_state @@ -74,60 +72,82 @@ def my_eval(model, stream, metrics): if __name__ == '__main__': - # TODO: add checkpointing - # TODO: add dynamic optimizers - bm = SplitMNIST(n_experiences=5) train_stream, test_stream = bm.train_stream, bm.test_stream + # we split the training stream into online experiences ocl_stream = split_online_stream(train_stream, experience_size=256, seed=1234) + # we add class attributes to the experiences ocl_stream = with_classes_timeline(ocl_stream) # agent state collects all the objects that are needed during training # many of these objects will have some state that is updated during training. - # The training function returns the updated agent state at each step. + # Put all the stateful training objects here. Calling the agent `pre_adapt` + # and `post_adapt` methods will adapt all the objects. + # Sometimes, it may be useful to also put non-stateful objects (e.g. losses) + # to easily switch between them while keeping the same training loop. agent = Agent() agent.replay = ReservoirSamplingBuffer(max_size=200) agent.loss = MaskedCrossEntropy() agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) - # model + # Avalanche models support dynamic addition and expansion of parameters agent.model = SimpleMLP().cuda() agent.model.classifier = IncrementalClassifier(in_features=512).cuda() agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e)) - # optim and scheduler + # optimizer and scheduler agent.opt = SGD(agent.model.parameters(), lr=0.001) - agent.add_pre_hooks(lambda a, e: reset_optimizer(a.opt, a.model)) # TODO: update after albin PR is merged - # agent_state.opt = DynamicOptimizer(opt) # TODO: maybe we can do this with some hooks??? agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) + # we use a hook to update the optimizer before each experience. + # This is needed because the model's parameters may change if you are using + # a dynamic model. + agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()))) + # we update the lr scheduler after each experience (not every epoch!) agent.add_post_hooks(lambda a, e: a.scheduler.step()) print("Start training...") metrics = [Accuracy()] - mc = MetricCollector() + mc = MetricCollector() # we put all the metric values here for exp in ocl_stream: - if exp.logging().is_first_subexp: + if exp.logging().is_first_subexp: # new drift print("training on classes ", exp.classes_in_this_experience) + # The training function returns the updated agent state at each step. agent = train_experience(agent, exp, epochs=2) if exp.logging().is_last_subexp: + # after learning new classes, do the eval + # notice that even in online settings we use the non-online stream + # for evaluation because each experience in the original stream + # corresponds to a separate data distribution. res = my_eval(agent.model, train_stream, metrics) mc.update(res, stream=train_stream) - print("RES: ", res) + print("TRAIN METRICS: ", res) res = my_eval(agent.model, test_stream, metrics) mc.update(res, stream=test_stream) - print("EVAL RES: ", res) + print("TEST METRICS: ", res) print() + os.makedirs("./logs", exist_ok=True) + torch.save(agent.model.state_dict(), "./logs/model.pth") + print(mc.get_dict()) + mc.to_json("./logs/mc.json") # save metrics acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) - print(mc.get("Accuracy", exp_reduce=None, stream=test_stream)) - print(acc_timeline) + print("Accuracy:\n", mc.get("Accuracy", exp_reduce=None, stream=test_stream)) + print("AVG Accuracy: ", acc_timeline) acc_matrix = mc.get("Accuracy", exp_reduce=None, stream=test_stream) fig = plot_metric_matrix(acc_matrix, title="Accuracy - Train") - fig.savefig("accuracy.png") + fig.savefig("./logs/accuracy.png") fm = forgetting(acc_matrix) + print("Forgetting:\n", fm) fig = plot_metric_matrix(fm, title="Forgetting - Train") - fig.savefig("forgetting.png") + fig.savefig("./logs/forgetting.png") + + # BONUS: example of re-loading the model from disk + model = SimpleMLP() + model.classifier = IncrementalClassifier(in_features=512) + for exp in ocl_stream: # we need to adapt the model's architecture again + avalanche_model_adaptation(model, exp) + model.load_state_dict(torch.load("./logs/model.pth")) From e721e1779448bdee751c3b23ba7fa5655f15571f Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Thu, 28 Mar 2024 13:21:00 +0100 Subject: [PATCH 12/18] update and fix tests --- avalanche/_annotations.py | 1 - avalanche/core.py | 96 ++++++++++++++++++++++++---- avalanche/evaluation/collector.py | 6 +- avalanche/training/storage_policy.py | 3 +- examples/updatable_objects.py | 4 +- tests/evaluation/test_collector.py | 41 +++++++----- tests/test_core.py | 56 ++++++++++++++++ 7 files changed, 173 insertions(+), 34 deletions(-) create mode 100644 tests/test_core.py diff --git a/avalanche/_annotations.py b/avalanche/_annotations.py index e2795b116..c23ac7d54 100644 --- a/avalanche/_annotations.py +++ b/avalanche/_annotations.py @@ -50,7 +50,6 @@ def wrapper(*args, **kwargs): return decorator -# TODO: show deprecation warning only once def deprecated(version: float, reason: str): """Decorator to mark functions as deprecated. diff --git a/avalanche/core.py b/avalanche/core.py index ff31809d7..efe384d8c 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -12,10 +12,53 @@ class Agent: + """Avalanche Continual Learning Agent. + + The agent stores the state needed by continual learning training methods, + such as optimizers, models, regularization losses. + You can add any objects as attributes dynamically: + + .. code-block:: + + agent = Agent() + agent.replay = ReservoirSamplingBuffer(max_size=200) + agent.loss = MaskedCrossEntropy() + agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2) + agent.model = my_model + agent.opt = SGD(agent.model.parameters(), lr=0.001) + agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) + + Many CL objects will need to perform some operation before or + after training on each experience. This is supported via the `Adaptable` + Protocol, which requires the `pre_adapt` and `post_adapt` methods. + To call the pre/post adaptation you can implement your training loop + like in the following example: + + .. code-block:: + + def train(agent, exp): + agent.pre_adapt(exp) + # do training here + agent.post_update(exp) + + Objects that implement the `Adaptable` Protocol will be called by the Agent. + + You can also add additional functionality to the adaptation phases with + hooks. For example: + + .. code-block:: + agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()))) + # we update the lr scheduler after each experience (not every epoch!) + agent.add_post_hooks(lambda a, e: a.scheduler.step()) + + + """ def __init__(self, verbose=False): - # TODO: doc - # TODO: test pre_update call - # TODO: test post_update call + """Init. + + :param verbose: If True, print every time an adaptable object or hook + is called during the adaptation. Useful for debugging. + """ self._updatable_objects = [] self.verbose = verbose self._pre_hooks = [] @@ -29,7 +72,12 @@ def __setattr__(self, name, value): print("Added updatable object ", value) def pre_adapt(self, exp): - # TODO: doc + """Pre-adaptation. + + Remember to call this before training on a new experience. + + :param exp: current experience + """ for uo in self._updatable_objects: if hasattr(uo, 'pre_adapt'): uo.pre_adapt(self, exp) @@ -41,7 +89,12 @@ def pre_adapt(self, exp): foo(self, exp) def post_adapt(self, exp): - # TODO: doc + """Post-adaptation. + + Remember to call this after training on a new experience. + + :param exp: current experience + """ for uo in self._updatable_objects: if hasattr(uo, 'post_adapt'): uo.post_adapt(self, exp) @@ -53,20 +106,39 @@ def post_adapt(self, exp): foo(self, exp) def add_pre_hooks(self, foo): - # TODO: doc - # TODO: test + """Add a pre-adaptation hooks + + Hooks take two arguments: ``. + + :param foo: the hook function + """ self._pre_hooks.append(foo) def add_post_hooks(self, foo): - # TODO: doc - # TODO: test + """Add a post-adaptation hooks + + Hooks take two arguments: ``. + + :param foo: the hook function + """ self._post_hooks.append(foo) -@runtime_checkable class Adaptable(Protocol): - # TODO: doc - # TODO: test runtime-checkability + """Adaptable objects Protocol. + + These class documents the Adaptable objects API but it is not necessary + for an object to inherit from it since the `Agent` will search for the methods + dynamically. + + Adaptable objects are objects that require to run their `pre_adapt` and + `post_adapt` methods before (and after, respectively) training on each + experience. + + Adaptable objects can implement only the method that they need since the + `Agent` will look for the methods dynamically and call it only if it is + implemented. + """ def pre_adapt(self, agent: Agent, exp: CLExperience): pass diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index f0b964e16..1720286d1 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -15,6 +15,11 @@ class MetricCollector: Functionlity includes the ability to store metrics over time and compute aggregated values of them. Serialization is supported via json files. + You can pass a stream to the `update` method to compute separate metrics + for each stream. Metric names will become `{stream.name}/{metric_name}`. + When you recover the stream using the `get` method, you need to pass the + same stream. + Example usage: .. code-block:: python @@ -52,7 +57,6 @@ def update(self, res, *, stream=None): :param stream: optional stream. If a stream is given the full metric name becomes `f'{stream.name}/{metric_name}'`. """ - # TODO: test multi groups for k, v in res.items(): # optional safety check on metric shape # metrics are expected to have one value for each experience in the stream diff --git a/avalanche/training/storage_policy.py b/avalanche/training/storage_policy.py index cadb176bc..ca450a9dd 100644 --- a/avalanche/training/storage_policy.py +++ b/avalanche/training/storage_policy.py @@ -66,7 +66,6 @@ def update(self, strategy: "SupervisedTemplate", **kwargs): # this should work until we deprecate self.update self.post_adapt(strategy, strategy.experience) - @abstractmethod def post_adapt(self, agent_state, exp): """Update `self.buffer` using the agent state and current experience. @@ -74,7 +73,7 @@ def post_adapt(self, agent_state, exp): :param exp: :return: """ - ... + pass @abstractmethod def resize(self, strategy: "SupervisedTemplate", new_size: int): diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 75e941d71..4ee9d5c3c 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -50,7 +50,6 @@ def train_experience(agent_state, exp, epochs=10): l.backward() agent_state.opt.step() agent_state.post_adapt(exp) # update objects and call post_hooks - return agent_state @torch.no_grad() @@ -111,8 +110,7 @@ def my_eval(model, stream, metrics): for exp in ocl_stream: if exp.logging().is_first_subexp: # new drift print("training on classes ", exp.classes_in_this_experience) - # The training function returns the updated agent state at each step. - agent = train_experience(agent, exp, epochs=2) + train_experience(agent, exp, epochs=2) if exp.logging().is_last_subexp: # after learning new classes, do the eval # notice that even in online settings we use the non-online stream diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py index 456853395..2fc794788 100644 --- a/tests/evaluation/test_collector.py +++ b/tests/evaluation/test_collector.py @@ -13,42 +13,53 @@ class MetricCollectorTests(unittest.TestCase): def test_time_reduce(self): # we just need an object with __len__ - fake_stream = [ - SimpleNamespace(dataset=list(range(1))), - SimpleNamespace(dataset=list(range(2))), - ] + class ToyStream: + def __init__(self): + self.els = [ + SimpleNamespace(dataset=list(range(1))), + SimpleNamespace(dataset=list(range(2))), + ] + self.name = "toy" + + def __getitem__(self, item): + return self.els[item] + + def __len__(self): + return len(self.els) + + fake_stream = ToyStream() # m = MockMetric([0, 1, 2, 3, 4, 5]) - mc = MetricCollector(fake_stream) + mc = MetricCollector() stream_of_mvals = [ [1, 3], [5, 7], [11, 13] ] for vvv in stream_of_mvals: - mc.update({"FakeMetric": vvv}) + mc.update({"toy/FakeMetric": vvv}) # time_reduce = None - v = mc.get("FakeMetric", time_reduce=None, exp_reduce=None) + v = mc.get("FakeMetric", time_reduce=None, exp_reduce=None, stream=fake_stream) np.testing.assert_array_almost_equal(v, stream_of_mvals) - v = mc.get("FakeMetric", time_reduce=None, exp_reduce="sample_mean") + v = mc.get("FakeMetric", time_reduce=None, exp_reduce="sample_mean", stream=fake_stream) np.testing.assert_array_almost_equal(v, [(1+3*2)/3, (5+7*2)/3, (11+13*2)/3]) - v = mc.get("FakeMetric", time_reduce=None, exp_reduce="experience_mean") + v = mc.get("FakeMetric", time_reduce=None, exp_reduce="experience_mean", stream=fake_stream) np.testing.assert_array_almost_equal(v, [(1+3)/2, (5+7)/2, (11+13)/2]) # time = "last" - v = mc.get("FakeMetric", time_reduce="last", exp_reduce=None) + v = mc.get("FakeMetric", time_reduce="last", exp_reduce=None, stream=fake_stream) np.testing.assert_array_almost_equal(v, [11, 13]) - v = mc.get("FakeMetric", time_reduce="last", exp_reduce="sample_mean") + v = mc.get("FakeMetric", time_reduce="last", exp_reduce="sample_mean", stream=fake_stream) self.assertAlmostEqual(v, (11+13*2)/3) - v = mc.get("FakeMetric", time_reduce="last", exp_reduce="experience_mean") + v = mc.get("FakeMetric", time_reduce="last", exp_reduce="experience_mean", stream=fake_stream) self.assertAlmostEqual(v, (11+13)/2) # time_reduce = "mean" - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce=None) + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce=None, stream=fake_stream) np.testing.assert_array_almost_equal(v, [(1+5+11)/3, (3+7+13)/3]) - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="sample_mean") + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="sample_mean", stream=fake_stream) self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3*2) / 3) - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="experience_mean") + v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="experience_mean", stream=fake_stream) self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3) / 2) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 000000000..eb088d368 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,56 @@ +import unittest + +from avalanche.core import Agent + + +class SimpleAdaptableObject: + def __init__(self): + self.has_called_pre = False + self.has_called_post = False + + def pre_adapt(self, agent, exp): + self.has_called_pre = True + + def post_adapt(self, agent, exp): + self.has_called_post = True + + +class AgentTests(unittest.TestCase): + def test_adapt_hooks(self): + agent = Agent() + + IS_FOO_PRE_CALLED = False + def foo_pre(a, e): + nonlocal IS_FOO_PRE_CALLED + IS_FOO_PRE_CALLED = True + agent.add_pre_hooks(foo_pre) + + IS_FOO_POST_CALLED = False + def foo_post(a, e): + nonlocal IS_FOO_POST_CALLED + IS_FOO_POST_CALLED = True + agent.add_post_hooks(foo_post) + + e = None + agent.pre_adapt(e) + assert IS_FOO_PRE_CALLED + assert not IS_FOO_POST_CALLED + + IS_FOO_PRE_CALLED = False + agent.post_adapt(e) + assert IS_FOO_POST_CALLED + assert not IS_FOO_PRE_CALLED + + def test_adapt_objects(self): + agent = Agent() + agent.obj = SimpleAdaptableObject() + + e = None + agent.pre_adapt(e) + assert agent.obj.has_called_pre + assert not agent.obj.has_called_post + + agent.obj.has_called_pre = False + agent.post_adapt(e) + assert agent.obj.has_called_post + assert not agent.obj.has_called_pre From 639454eef8564b2205bf499d09dbafa4469f1884 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 3 Apr 2024 17:48:09 +0200 Subject: [PATCH 13/18] formatting --- avalanche/benchmarks/scenarios/online.py | 7 +-- avalanche/core.py | 8 ++- avalanche/evaluation/collector.py | 19 +++++-- avalanche/evaluation/functional.py | 6 +- avalanche/evaluation/plot_utils.py | 18 +++--- avalanche/models/dynamic_optimizers.py | 2 +- examples/updatable_objects.py | 17 ++++-- tests/evaluation/test_collector.py | 71 +++++++++++++++++------- tests/evaluation/test_functional.py | 26 ++------- tests/evaluation/test_plots.py | 2 +- tests/test_core.py | 4 ++ 11 files changed, 107 insertions(+), 73 deletions(-) diff --git a/avalanche/benchmarks/scenarios/online.py b/avalanche/benchmarks/scenarios/online.py index 6f692975b..1e51dfadc 100644 --- a/avalanche/benchmarks/scenarios/online.py +++ b/avalanche/benchmarks/scenarios/online.py @@ -254,7 +254,7 @@ def _default_online_split( access_task_boundaries: bool, exp: DatasetExperience, size: int, - seed: int + seed: int, ): return FixedSizeExperienceSplitter( experience=exp, @@ -262,7 +262,7 @@ def _default_online_split( shuffle=shuffle, drop_last=drop_last, access_task_boundaries=access_task_boundaries, - seed=seed + seed=seed, ) @@ -316,8 +316,7 @@ def split_online_stream( # However, MyPy does not understand what a partial is -_- def default_online_split_wrapper(e, e_sz): return _default_online_split( - shuffle, drop_last, access_task_boundaries, e, e_sz, - seed=seed + shuffle, drop_last, access_task_boundaries, e, e_sz, seed=seed ) split_strategy = default_online_split_wrapper diff --git a/avalanche/core.py b/avalanche/core.py index efe384d8c..4e0b5b279 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -53,6 +53,7 @@ def train(agent, exp): """ + def __init__(self, verbose=False): """Init. @@ -66,7 +67,7 @@ def __init__(self, verbose=False): def __setattr__(self, name, value): super().__setattr__(name, value) - if hasattr(value, 'pre_adapt') or hasattr(value, 'post_adapt'): + if hasattr(value, "pre_adapt") or hasattr(value, "post_adapt"): self._updatable_objects.append(value) if self.verbose: print("Added updatable object ", value) @@ -79,7 +80,7 @@ def pre_adapt(self, exp): :param exp: current experience """ for uo in self._updatable_objects: - if hasattr(uo, 'pre_adapt'): + if hasattr(uo, "pre_adapt"): uo.pre_adapt(self, exp) if self.verbose: print("pre_adapt ", uo) @@ -96,7 +97,7 @@ def post_adapt(self, exp): :param exp: current experience """ for uo in self._updatable_objects: - if hasattr(uo, 'post_adapt'): + if hasattr(uo, "post_adapt"): uo.post_adapt(self, exp) if self.verbose: print("post_adapt ", uo) @@ -139,6 +140,7 @@ class Adaptable(Protocol): `Agent` will look for the methods dynamically and call it only if it is implemented. """ + def pre_adapt(self, agent: Agent, exp: CLExperience): pass diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 1720286d1..30825b000 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -1,6 +1,7 @@ """ Utilities for metrics collection outside of Avalanche training Templates. """ + import json import numpy as np @@ -31,6 +32,7 @@ class MetricCollector: acc_timeline = mc.get("Accuracy", exp_reduce="sample_mean", stream=test_stream) """ + def __init__(self): """Init.""" self.metrics_res = {} @@ -64,8 +66,10 @@ def update(self, res, *, stream=None): if stream.name not in self._stream_len: self._init_stream(stream) if len(v) != self._stream_len[stream.name]: - raise ValueError(f"Length does not correspond to stream. " - f"Found {len(v)}, expected {self._stream_len[stream.name]}") + raise ValueError( + f"Length does not correspond to stream. " + f"Found {len(v)}, expected {self._stream_len[stream.name]}" + ) # update metrics dictionary if stream is not None: @@ -99,7 +103,9 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): if stream is not None: name = f"{stream.name}/{name}" if name not in self.metrics_res: - print(f"{name} metric was not found. Maybe you forgot the `stream` argument?") + print( + f"{name} metric was not found. Maybe you forgot the `stream` argument?" + ) mvals = np.array(self.metrics_res[name]) if exp_reduce is None: @@ -108,7 +114,8 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): if stream is None: raise ValueError( "If you want to use `exp_reduce == sample_mean` you need to provide" - "the `stream` argument to the `update` and `get` methods.") + "the `stream` argument to the `update` and `get` methods." + ) if stream.name not in self._coeffs: self._init_stream(stream) mvals = mvals * self._coeffs[stream.name][None, :] # weight by num.samples @@ -152,7 +159,7 @@ def load_json(self, fname): :param fname: file name """ # TODO: test - with open(fname, 'w') as f: + with open(fname, "w") as f: self.metrics_res = json.load(f) def to_json(self, fname): @@ -162,5 +169,5 @@ def to_json(self, fname): :return: """ # TODO: test - with open(fname, 'w') as f: + with open(fname, "w") as f: json.dump(obj=self.metrics_res, fp=f) diff --git a/avalanche/evaluation/functional.py b/avalanche/evaluation/functional.py index f591b4a86..862e8a240 100644 --- a/avalanche/evaluation/functional.py +++ b/avalanche/evaluation/functional.py @@ -27,8 +27,10 @@ def forgetting(accuracy_matrix, boundary_indices=None): acc_first = accuracy_matrix[t_task][k] # before learning exp k, forgetting is zero - forgetting_matrix[:t_task+1, k] = 0 + forgetting_matrix[: t_task + 1, k] = 0 # then, it's acc_first - acc[t, k] - forgetting_matrix[t_task+1:, k] = acc_first - accuracy_matrix[t_task+1:, k] + forgetting_matrix[t_task + 1 :, k] = ( + acc_first - accuracy_matrix[t_task + 1 :, k] + ) return forgetting_matrix diff --git a/avalanche/evaluation/plot_utils.py b/avalanche/evaluation/plot_utils.py index 9ee83cd62..0c7b5f591 100644 --- a/avalanche/evaluation/plot_utils.py +++ b/avalanche/evaluation/plot_utils.py @@ -31,13 +31,7 @@ def learning_curves_plot(all_metrics: dict): return fig -def plot_metric_matrix( - metric_matrix, - title, - *, - ax=None, - text_values=True -): +def plot_metric_matrix(metric_matrix, title, *, ax=None, text_values=True): """Plot a matrix of metrics (e.g. forgetting over time). :param metric_matrix: 2D accuracy matrix with shape @@ -61,6 +55,12 @@ def plot_metric_matrix( if text_values: for i in range(len(metric_matrix)): for j in range(len(metric_matrix[0])): - ax.text(j, i, f"{metric_matrix[i][j]:.3f}", - ha="center", va="center", color="w") + ax.text( + j, + i, + f"{metric_matrix[i][j]:.3f}", + ha="center", + va="center", + color="w", + ) return fig diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index e5dc6108f..33d6eb72d 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -348,4 +348,4 @@ def add_new_params_to_optimizer(optimizer, new_params): optimizer.add_param_group({"params": new_params}) -__all__ = ["add_new_params_to_optimizer", "reset_optimizer", "update_optimizer"] \ No newline at end of file +__all__ = ["add_new_params_to_optimizer", "reset_optimizer", "update_optimizer"] diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 4ee9d5c3c..481a0bcbe 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -25,15 +25,18 @@ def train_experience(agent_state, exp, epochs=10): agent_state.model.train() - data = exp.dataset.train() # avalanche datasets have train/eval modes to switch augmentations + data = ( + exp.dataset.train() + ) # avalanche datasets have train/eval modes to switch augmentations agent_state.pre_adapt(exp) # update objects and call pre_hooks for ep in range(epochs): if len(agent_state.replay.buffer) > 0: # if the replay buffer is not empty we sample from # current data and replay buffer in parallel - dl = ReplayDataLoader(data, agent_state.replay.buffer, - batch_size=32, shuffle=True) + dl = ReplayDataLoader( + data, agent_state.replay.buffer, batch_size=32, shuffle=True + ) else: dl = DataLoader(data, batch_size=32, shuffle=True) @@ -70,7 +73,7 @@ def my_eval(model, stream, metrics): return res -if __name__ == '__main__': +if __name__ == "__main__": bm = SplitMNIST(n_experiences=5) train_stream, test_stream = bm.train_stream, bm.test_stream # we split the training stream into online experiences @@ -100,7 +103,11 @@ def my_eval(model, stream, metrics): # we use a hook to update the optimizer before each experience. # This is needed because the model's parameters may change if you are using # a dynamic model. - agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()))) + agent.add_pre_hooks( + lambda a, e: update_optimizer( + a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()) + ) + ) # we update the lr scheduler after each experience (not every epoch!) agent.add_post_hooks(lambda a, e: a.scheduler.step()) diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py index 2fc794788..482ccbcb3 100644 --- a/tests/evaluation/test_collector.py +++ b/tests/evaluation/test_collector.py @@ -30,37 +30,68 @@ def __len__(self): fake_stream = ToyStream() # m = MockMetric([0, 1, 2, 3, 4, 5]) mc = MetricCollector() - stream_of_mvals = [ - [1, 3], - [5, 7], - [11, 13] - ] + stream_of_mvals = [[1, 3], [5, 7], [11, 13]] for vvv in stream_of_mvals: mc.update({"toy/FakeMetric": vvv}) # time_reduce = None v = mc.get("FakeMetric", time_reduce=None, exp_reduce=None, stream=fake_stream) np.testing.assert_array_almost_equal(v, stream_of_mvals) - v = mc.get("FakeMetric", time_reduce=None, exp_reduce="sample_mean", stream=fake_stream) - np.testing.assert_array_almost_equal(v, [(1+3*2)/3, (5+7*2)/3, (11+13*2)/3]) - v = mc.get("FakeMetric", time_reduce=None, exp_reduce="experience_mean", stream=fake_stream) - np.testing.assert_array_almost_equal(v, [(1+3)/2, (5+7)/2, (11+13)/2]) + v = mc.get( + "FakeMetric", time_reduce=None, exp_reduce="sample_mean", stream=fake_stream + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3 * 2) / 3, (5 + 7 * 2) / 3, (11 + 13 * 2) / 3] + ) + v = mc.get( + "FakeMetric", + time_reduce=None, + exp_reduce="experience_mean", + stream=fake_stream, + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3) / 2, (5 + 7) / 2, (11 + 13) / 2] + ) # time = "last" - v = mc.get("FakeMetric", time_reduce="last", exp_reduce=None, stream=fake_stream) + v = mc.get( + "FakeMetric", time_reduce="last", exp_reduce=None, stream=fake_stream + ) np.testing.assert_array_almost_equal(v, [11, 13]) - v = mc.get("FakeMetric", time_reduce="last", exp_reduce="sample_mean", stream=fake_stream) - self.assertAlmostEqual(v, (11+13*2)/3) - v = mc.get("FakeMetric", time_reduce="last", exp_reduce="experience_mean", stream=fake_stream) - self.assertAlmostEqual(v, (11+13)/2) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="sample_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, (11 + 13 * 2) / 3) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="experience_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, (11 + 13) / 2) # time_reduce = "mean" - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce=None, stream=fake_stream) - np.testing.assert_array_almost_equal(v, [(1+5+11)/3, (3+7+13)/3]) - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="sample_mean", stream=fake_stream) - self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3*2) / 3) - v = mc.get("FakeMetric", time_reduce="mean", exp_reduce="experience_mean", stream=fake_stream) - self.assertAlmostEqual(v, ((1+5+11)/3 + (3+7+13)/3) / 2) + v = mc.get( + "FakeMetric", time_reduce="mean", exp_reduce=None, stream=fake_stream + ) + np.testing.assert_array_almost_equal(v, [(1 + 5 + 11) / 3, (3 + 7 + 13) / 3]) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="sample_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3 * 2) / 3) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="experience_mean", + stream=fake_stream, + ) + self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3) / 2) if __name__ == "__main__": diff --git a/tests/evaluation/test_functional.py b/tests/evaluation/test_functional.py index 650fe1065..7c78a509f 100644 --- a/tests/evaluation/test_functional.py +++ b/tests/evaluation/test_functional.py @@ -10,33 +10,15 @@ class FunctionalMetricTests(unittest.TestCase): def test_diag_forgetting(self): - vals = [ - [1.0, 0.0, 0.0], - [0.5, 0.8, 1.0], - [0.2, 1.0, 0.5] - ] - target = [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.8, -0.2, 0.0] - ] + vals = [[1.0, 0.0, 0.0], [0.5, 0.8, 1.0], [0.2, 1.0, 0.5]] + target = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.8, -0.2, 0.0]] fm = forgetting(vals) np.testing.assert_array_almost_equal(fm, target) def test_non_diag_forgetting(self): """non-diagonal matrix requires explicit boundary indices""" - vals = [ - [1.0, 0.0, 0.0], - [0.5, 0.8, 1.0], - [0.5, 0.5, 1.0], - [0.2, 1.0, 0.5] - ] - target = [ - [0.0, 0.0, 0.0], - [0.5, 0.0, 0.0], - [0.5, 0.3, 0.0], - [0.8, -0.2, 0.0] - ] + vals = [[1.0, 0.0, 0.0], [0.5, 0.8, 1.0], [0.5, 0.5, 1.0], [0.2, 1.0, 0.5]] + target = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.5, 0.3, 0.0], [0.8, -0.2, 0.0]] fm = forgetting(vals, boundary_indices=[0, 1, 3]) np.testing.assert_array_almost_equal(fm, target) diff --git a/tests/evaluation/test_plots.py b/tests/evaluation/test_plots.py index 65123d901..2a628417a 100644 --- a/tests/evaluation/test_plots.py +++ b/tests/evaluation/test_plots.py @@ -14,7 +14,7 @@ def test_basic(self): [0.57178465, 0.65048787, 0.7563081, 0.56941323, 0.4562], [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], [0.65802592, 0.52339254, 0.58947543, 0.68339576, 0.5474], - [0.4995015, 0.58819114, 0.57320717, 0.62322097, 0.6925] + [0.4995015, 0.58819114, 0.57320717, 0.62322097, 0.6925], ] plot_metric_matrix(mvals, title="Accuracy - Train") diff --git a/tests/test_core.py b/tests/test_core.py index eb088d368..905f05bb6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,15 +20,19 @@ def test_adapt_hooks(self): agent = Agent() IS_FOO_PRE_CALLED = False + def foo_pre(a, e): nonlocal IS_FOO_PRE_CALLED IS_FOO_PRE_CALLED = True + agent.add_pre_hooks(foo_pre) IS_FOO_POST_CALLED = False + def foo_post(a, e): nonlocal IS_FOO_POST_CALLED IS_FOO_POST_CALLED = True + agent.add_post_hooks(foo_post) e = None From 1808f20df4f0d8d64e2d401e7ba01a032ce02933 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 10:16:47 +0200 Subject: [PATCH 14/18] add custom weights argument for exp_reduce in MetricCollector --- avalanche/evaluation/collector.py | 14 +++++++++++--- examples/updatable_objects.py | 8 +++----- tests/evaluation/test_collector.py | 26 ++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 30825b000..475459704 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -79,7 +79,7 @@ def update(self, res, *, stream=None): else: self.metrics_res[k] = [v] - def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): + def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None): """Returns a metric value given its name and aggregation method. :param name: name of the metric. @@ -91,14 +91,20 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): - None (default) does not use any aggregation - `sample_mean` is an average weighted by the number of samples in each experience - `experience_mean` is an experience average. + - 'weighted_sum' is a weighted sum of the experiences using the `weights` argument. :param stream: stream that was used to compute the metric. This is needed to build the full metric name if the get was called with a stream name and if `exp_reduce == sample_mean` to get the number of samples from each experience. + :param weights: weights for each experience when `exp_reduce == 'weighted_sum`. :return: aggregated metric value. """ assert time_reduce in {None, "last", "mean"} - assert exp_reduce in {None, "sample_mean", "experience_mean"} + assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"} + if exp_reduce == "weighted_sum": + assert weights is not None, "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`." + else: + assert weights is None, "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`" if stream is not None: name = f"{stream.name}/{name}" @@ -122,11 +128,13 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): mvals = mvals.sum(axis=1) # weighted avg across exp. elif exp_reduce == "experience_mean": mvals = mvals.mean(axis=1) # avg across exp. + elif exp_reduce == "weighted_sum": + weights = np.array(weights)[None, :] + mvals = (mvals * weights).sum(axis=1) else: raise ValueError("BUG. It should never get here.") if time_reduce is None: - pass # nothing to do here elif time_reduce == "last": mvals = mvals[-1] # last timestep diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 481a0bcbe..12532f729 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -25,10 +25,8 @@ def train_experience(agent_state, exp, epochs=10): agent_state.model.train() - data = ( - exp.dataset.train() - ) # avalanche datasets have train/eval modes to switch augmentations - + # avalanche datasets have train/eval modes to switch augmentations + data = exp.dataset.train() agent_state.pre_adapt(exp) # update objects and call pre_hooks for ep in range(epochs): if len(agent_state.replay.buffer) > 0: @@ -64,7 +62,7 @@ def my_eval(model, stream, metrics): res = {uo.__class__.__name__: [] for uo in metrics} for exp in stream: [uo.reset() for uo in metrics] - dl = DataLoader(exp.dataset, batch_size=512, num_workers=8) + dl = DataLoader(exp.dataset.eval(), batch_size=512, num_workers=8) for x, y, _ in dl: x, y = x.cuda(), y.cuda() yp = model(x) diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py index 482ccbcb3..e6146fb99 100644 --- a/tests/evaluation/test_collector.py +++ b/tests/evaluation/test_collector.py @@ -52,6 +52,16 @@ def __len__(self): np.testing.assert_array_almost_equal( v, [(1 + 3) / 2, (5 + 7) / 2, (11 + 13) / 2] ) + v = mc.get( + "FakeMetric", + time_reduce=None, + exp_reduce="weighted_sum", + weights=[1, 2], + stream=fake_stream + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3*2), (5 + 7*2), (11 + 13*2)] + ) # time = "last" v = mc.get( @@ -72,6 +82,14 @@ def __len__(self): stream=fake_stream, ) self.assertAlmostEqual(v, (11 + 13) / 2) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2] + ) + self.assertAlmostEqual(v, 11 + 13*2) # time_reduce = "mean" v = mc.get( @@ -92,6 +110,14 @@ def __len__(self): stream=fake_stream, ) self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3) / 2) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2] + ) + self.assertAlmostEqual(v, ((1 + 3*2) + (5 + 7*2) + (11 + 13*2)) / 3) if __name__ == "__main__": From 847af49e71065173b144220c02c80085e2678f61 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 10:20:14 +0200 Subject: [PATCH 15/18] update doc --- avalanche/core.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/avalanche/core.py b/avalanche/core.py index 4e0b5b279..4a744f813 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -1,4 +1,10 @@ -# TODO: doc +""" +This module contains Protocols for some of the main components of Avalanche, +such as strategy plugins and the agent state. + +Most of these protocols are checked dynamically at runtime, so it is often not +necessary to inherit explicit from them or implement all the methods. +""" from abc import ABC from typing import Any, TypeVar, Generic, Protocol, runtime_checkable from typing import TYPE_CHECKING @@ -39,7 +45,7 @@ class Agent: def train(agent, exp): agent.pre_adapt(exp) # do training here - agent.post_update(exp) + agent.post_adapt(exp) Objects that implement the `Adaptable` Protocol will be called by the Agent. From 1128d4321ad53c38f569d9da795a32088da4e359 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 11:44:50 +0200 Subject: [PATCH 16/18] add DynamicOptimizer wrapper --- avalanche/models/dynamic_modules.py | 2 +- avalanche/models/dynamic_optimizers.py | 53 ++++++++++++++++++++++++- examples/updatable_objects.py | 22 +++++----- tests/models/test_dynamic_optimizers.py | 34 ++++++++++++++++ 4 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 tests/models/test_dynamic_optimizers.py diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index b2466dff9..9a43cda54 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -73,7 +73,7 @@ def __init__(self, auto_adapt=True): super().__init__() self._auto_adapt = auto_adapt - def pre_adapt(self, experience): + def pre_adapt(self, agent, experience): """ Calls self.adaptation recursively accross the hierarchy of pytorch module childrens diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index 33d6eb72d..a1a97523c 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -18,7 +18,9 @@ import numpy as np -from avalanche._annotations import deprecated +from avalanche._annotations import deprecated, experimental +from avalanche.benchmarks import CLExperience +from avalanche.core import Adaptable, Agent colors = { "END": "\033[0m", @@ -31,6 +33,55 @@ colors[None] = colors["END"] +@experimental("New dynamic optimizers. The API may slightly change in the next versions.") +class DynamicOptimizer(Adaptable): + """Avalanche dynamic optimizer. + + In continual learning, many model architecture may change over time (e.g. + adding new units to the classifier). This is handled by the `DynamicModule`. + Changing the model's architecture requires updating the optimizer to add + the new parameters in the optimizer's state. + + This class provides a simple wrapper to handle the optimizer + update via the `Adaptable` Protocol. + + This class provides direct methods only for `zero_grad` and `step` to support + basic training functionality. Other methods of the base optimizers must be + called by using the base optimizer directly (e.g. `self.optim.add_param_group`). + + NOTE: the model must be adapted *before* calling this method. + To ensure this, ensure that the optimizer is added to the agent state + after the model: + + .. code-block:: + + agent.model = model + agent.optimizer = optimizer + + # ... more init code + + # pre_adapt will call the pre_adapt methods in order, + # first model.pre_adapt, then optimizer.pre_adapt + agent.pre_adapt(experience) + """ + def __init__(self, optim): + self.optim = optim + + def zero_grad(self): + self.optim.zero_grad() + + def step(self): + self.optim.step() + + def pre_adapt(self, agent: Agent, exp: CLExperience): + """Adapt the optimizer before training on the current experience.""" + update_optimizer( + self.optim, + new_params=dict(agent.model.named_parameters()), + optimized_params=dict(agent.model.named_parameters()) + ) + + def _map_optimized_params(optimizer, parameters, old_params=None): """ Establishes a mapping between a list of named parameters and the parameters diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 12532f729..5313561d6 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -18,7 +18,7 @@ from avalanche.evaluation.plot_utils import plot_metric_matrix from avalanche.models import SimpleMLP, IncrementalClassifier from avalanche.models.dynamic_modules import avalanche_model_adaptation -from avalanche.models.dynamic_optimizers import update_optimizer +from avalanche.models.dynamic_optimizers import update_optimizer, DynamicOptimizer from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy @@ -55,10 +55,11 @@ def train_experience(agent_state, exp, epochs=10): @torch.no_grad() def my_eval(model, stream, metrics): - # eval also becomes simpler. Notice how in Avalanche it's harder to check whether - # we are evaluating a single exp. or the whole stream. - # Now we evaluate each stream with a separate function call + """Evaluate `model` on `stream` computing `metrics`. + Returns a dictionary {metric_name: list-of-results}. + """ + model.eval() res = {uo.__class__.__name__: [] for uo in metrics} for exp in stream: [uo.reset() for uo in metrics] @@ -96,16 +97,13 @@ def my_eval(model, stream, metrics): agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e)) # optimizer and scheduler - agent.opt = SGD(agent.model.parameters(), lr=0.001) - agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) - # we use a hook to update the optimizer before each experience. + # we have update the optimizer before each experience. # This is needed because the model's parameters may change if you are using # a dynamic model. - agent.add_pre_hooks( - lambda a, e: update_optimizer( - a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()) - ) - ) + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt) + agent.scheduler = ExponentialLR(opt, gamma=0.999) + # we use a hook to call the scheduler. # we update the lr scheduler after each experience (not every epoch!) agent.add_post_hooks(lambda a, e: a.scheduler.step()) diff --git a/tests/models/test_dynamic_optimizers.py b/tests/models/test_dynamic_optimizers.py new file mode 100644 index 000000000..0012749ed --- /dev/null +++ b/tests/models/test_dynamic_optimizers.py @@ -0,0 +1,34 @@ +import unittest + +from torch.optim import SGD +from torch.utils.data import DataLoader + +from avalanche.core import Agent +from avalanche.models import SimpleMLP, as_multitask +from avalanche.models.dynamic_optimizers import DynamicOptimizer +from avalanche.training import MaskedCrossEntropy +from tests.unit_tests_utils import get_fast_benchmark + + +class TestDynamicOptimizers(unittest.TestCase): + def test_dynamic_optimizer(self): + bm = get_fast_benchmark(use_task_labels=True) + agent = Agent() + agent.loss = MaskedCrossEntropy() + agent.model = as_multitask(SimpleMLP(input_size=6), "classifier") + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt) + + for exp in bm.train_stream: + agent.model.train() + data = exp.dataset.train() + agent.pre_adapt(exp) + for ep in range(1): + dl = DataLoader(data, batch_size=32, shuffle=True) + for x, y, t in dl: + agent.opt.zero_grad() + yp = agent.model(x, t) + l = agent.loss(yp, y) + l.backward() + agent.opt.step() + agent.post_adapt(exp) From 447f168006dddca5414f9d096a875df9d93f3334 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 11:45:15 +0200 Subject: [PATCH 17/18] format --- avalanche/core.py | 1 + avalanche/evaluation/collector.py | 12 +++++++++--- avalanche/models/dynamic_optimizers.py | 7 +++++-- tests/evaluation/test_collector.py | 12 ++++++------ 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/avalanche/core.py b/avalanche/core.py index 4a744f813..e1ed5368b 100644 --- a/avalanche/core.py +++ b/avalanche/core.py @@ -5,6 +5,7 @@ Most of these protocols are checked dynamically at runtime, so it is often not necessary to inherit explicit from them or implement all the methods. """ + from abc import ABC from typing import Any, TypeVar, Generic, Protocol, runtime_checkable from typing import TYPE_CHECKING diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 475459704..9809f8e59 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -79,7 +79,9 @@ def update(self, res, *, stream=None): else: self.metrics_res[k] = [v] - def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None): + def get( + self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None + ): """Returns a metric value given its name and aggregation method. :param name: name of the metric. @@ -102,9 +104,13 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=N assert time_reduce in {None, "last", "mean"} assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"} if exp_reduce == "weighted_sum": - assert weights is not None, "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`." + assert ( + weights is not None + ), "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`." else: - assert weights is None, "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`" + assert ( + weights is None + ), "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`" if stream is not None: name = f"{stream.name}/{name}" diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index a1a97523c..f587aad47 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -33,7 +33,9 @@ colors[None] = colors["END"] -@experimental("New dynamic optimizers. The API may slightly change in the next versions.") +@experimental( + "New dynamic optimizers. The API may slightly change in the next versions." +) class DynamicOptimizer(Adaptable): """Avalanche dynamic optimizer. @@ -64,6 +66,7 @@ class DynamicOptimizer(Adaptable): # first model.pre_adapt, then optimizer.pre_adapt agent.pre_adapt(experience) """ + def __init__(self, optim): self.optim = optim @@ -78,7 +81,7 @@ def pre_adapt(self, agent: Agent, exp: CLExperience): update_optimizer( self.optim, new_params=dict(agent.model.named_parameters()), - optimized_params=dict(agent.model.named_parameters()) + optimized_params=dict(agent.model.named_parameters()), ) diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py index e6146fb99..78ee2ac34 100644 --- a/tests/evaluation/test_collector.py +++ b/tests/evaluation/test_collector.py @@ -57,10 +57,10 @@ def __len__(self): time_reduce=None, exp_reduce="weighted_sum", weights=[1, 2], - stream=fake_stream + stream=fake_stream, ) np.testing.assert_array_almost_equal( - v, [(1 + 3*2), (5 + 7*2), (11 + 13*2)] + v, [(1 + 3 * 2), (5 + 7 * 2), (11 + 13 * 2)] ) # time = "last" @@ -87,9 +87,9 @@ def __len__(self): time_reduce="last", exp_reduce="weighted_sum", stream=fake_stream, - weights=[1, 2] + weights=[1, 2], ) - self.assertAlmostEqual(v, 11 + 13*2) + self.assertAlmostEqual(v, 11 + 13 * 2) # time_reduce = "mean" v = mc.get( @@ -115,9 +115,9 @@ def __len__(self): time_reduce="mean", exp_reduce="weighted_sum", stream=fake_stream, - weights=[1, 2] + weights=[1, 2], ) - self.assertAlmostEqual(v, ((1 + 3*2) + (5 + 7*2) + (11 + 13*2)) / 3) + self.assertAlmostEqual(v, ((1 + 3 * 2) + (5 + 7 * 2) + (11 + 13 * 2)) / 3) if __name__ == "__main__": From 876140c8fe37dbfe990a06f5269d60f1c95e08cd Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Tue, 30 Apr 2024 11:27:21 +0200 Subject: [PATCH 18/18] fix DynamicOptimizer initialization --- avalanche/models/dynamic_optimizers.py | 34 ++++++-- tests/models/test_dynamic_optimizers.py | 100 +++++++++++++++++++++++- tests/models/test_models.py | 6 +- 3 files changed, 127 insertions(+), 13 deletions(-) diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index f587aad47..10e487dab 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -67,8 +67,19 @@ class DynamicOptimizer(Adaptable): agent.pre_adapt(experience) """ - def __init__(self, optim): + def __init__(self, optim, model, reset_state=False, verbose=False): self.optim = optim + self.reset_state = reset_state + self.verbose = verbose + + # initialize param-id + self._optimized_param_id = update_optimizer( + self.optim, + dict(model.named_parameters()), + None, + reset_state=True, + verbose=self.verbose, + ) def zero_grad(self): self.optim.zero_grad() @@ -78,10 +89,12 @@ def step(self): def pre_adapt(self, agent: Agent, exp: CLExperience): """Adapt the optimizer before training on the current experience.""" - update_optimizer( + self._optimized_param_id = update_optimizer( self.optim, - new_params=dict(agent.model.named_parameters()), - optimized_params=dict(agent.model.named_parameters()), + dict(agent.model.named_parameters()), + self._optimized_param_id, + reset_state=self.reset_state, + verbose=self.verbose, ) @@ -329,16 +342,22 @@ def update_optimizer( Newly added parameters are added by default to parameter group 0 - :param new_params: Dict (name, param) of new parameters + WARNING: the first call to `update_optimizer` must be done before + calling the model's adaptation. + + :param optimizer: the Optimizer object. + :param new_params: Dict (name, param) of new parameters. :param optimized_params: Dict (name, param) of - currently optimized parameters + currently optimized parameters. In most use cases, it will be `None in + the first call and the return value of the last `update_optimizer` call + for the subsequent calls. :param reset_state: Whether to reset the optimizer's state (i.e momentum). Defaults to False. :param remove_params: Whether to remove parameters that were in the optimizer but are not found in new parameters. For safety reasons, defaults to False. :param verbose: If True, prints information about inferred - parameter groups for new params + parameter groups for new params. :return: Dict (name, param) of optimized parameters """ @@ -381,7 +400,6 @@ def update_optimizer( new_p = new_params[key] group = param_structure[key].single_group optimizer.param_groups[group]["params"].append(new_p) - optimized_params[key] = new_p optimizer.state[new_p] = {} if reset_state: diff --git a/tests/models/test_dynamic_optimizers.py b/tests/models/test_dynamic_optimizers.py index 0012749ed..c7f8b6f25 100644 --- a/tests/models/test_dynamic_optimizers.py +++ b/tests/models/test_dynamic_optimizers.py @@ -1,10 +1,12 @@ import unittest +from torch.nn import CrossEntropyLoss from torch.optim import SGD from torch.utils.data import DataLoader from avalanche.core import Agent -from avalanche.models import SimpleMLP, as_multitask +from avalanche.models import SimpleMLP, as_multitask, IncrementalClassifier, MTSimpleMLP +from avalanche.models.dynamic_modules import avalanche_model_adaptation from avalanche.models.dynamic_optimizers import DynamicOptimizer from avalanche.training import MaskedCrossEntropy from tests.unit_tests_utils import get_fast_benchmark @@ -17,7 +19,7 @@ def test_dynamic_optimizer(self): agent.loss = MaskedCrossEntropy() agent.model = as_multitask(SimpleMLP(input_size=6), "classifier") opt = SGD(agent.model.parameters(), lr=0.001) - agent.opt = DynamicOptimizer(opt) + agent.opt = DynamicOptimizer(opt, agent.model, verbose=False) for exp in bm.train_stream: agent.model.train() @@ -32,3 +34,97 @@ def test_dynamic_optimizer(self): l.backward() agent.opt.step() agent.post_adapt(exp) + + def init_scenario(self, multi_task=False): + if multi_task: + model = MTSimpleMLP(input_size=6, hidden_size=10) + else: + model = SimpleMLP(input_size=6, hidden_size=10) + model.classifier = IncrementalClassifier(10, 1) + criterion = CrossEntropyLoss() + benchmark = get_fast_benchmark(use_task_labels=multi_task) + return model, criterion, benchmark + + def _is_param_in_optimizer(self, param, optimizer): + for group in optimizer.param_groups: + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return True + return False + + def _is_param_in_optimizer_group(self, param, optimizer): + for group_idx, group in enumerate(optimizer.param_groups): + for curr_p in group["params"]: + if hash(curr_p) == hash(param): + return group_idx + return None + + def test_optimizer_groups_clf_til(self): + """ + Tests the automatic assignation of new + MultiHead parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=True) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + agent = Agent() + agent.model = model + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + agent.optimizer = DynamicOptimizer(optimizer, model=model, verbose=False) + + for experience in benchmark.train_stream: + avalanche_model_adaptation(model, experience) + agent.optimizer.pre_adapt(agent, experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, agent.optimizer.optim) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1 + ) + + def test_optimizer_groups_clf_cil(self): + """ + Tests the automatic assignation of new + IncrementalClassifier parameters to the optimizer + """ + model, criterion, benchmark = self.init_scenario(multi_task=False) + + g1 = [] + g2 = [] + for n, p in model.named_parameters(): + if "classifier" in n: + g1.append(p) + else: + g2.append(p) + + agent = Agent() + agent.model = model + optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}]) + agent.optimizer = DynamicOptimizer(optimizer, model) + + for experience in benchmark.train_stream: + avalanche_model_adaptation(model, experience) + agent.optimizer.pre_adapt(agent, experience) + + for n, p in model.named_parameters(): + assert self._is_param_in_optimizer(p, agent.optimizer.optim) + if "classifier" in n: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 0 + ) + else: + self.assertEqual( + self._is_param_in_optimizer_group(p, agent.optimizer.optim), 1 + ) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index a5cbb4f51..fd1dda0f0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -369,10 +369,10 @@ def test_recursive_adaptation(self): sizes = {} for t, exp in enumerate(benchmark.train_stream): sizes[t] = np.max(exp.classes_in_this_experience) + 1 - model1.pre_adapt(exp) + model1.pre_adapt(None, exp) # Second adaptation should not change anything for t, exp in enumerate(benchmark.train_stream): - model1.pre_adapt(exp) + model1.pre_adapt(None, exp) for t, s in sizes.items(): self.assertEqual(s, model1.classifiers[str(t)].classifier.out_features) @@ -385,7 +385,7 @@ def test_recursive_loop(self): model2.layer2 = model1 benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True) - model1.pre_adapt(benchmark.train_stream[0]) + model1.pre_adapt(None, benchmark.train_stream[0]) def test_multi_head_classifier_masking(self): benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)