From e2dfaaeef94591fdd31e6113e5c174d56ac0e0df Mon Sep 17 00:00:00 2001
From: gogamid <36050790+gogamid@users.noreply.github.com>
Date: Tue, 13 Feb 2024 01:55:34 +0100
Subject: [PATCH 1/6] feat: add support for weighted loss in generative replay

---
 .../training/plugins/generative_replay.py    | 72 ++++++++++++++++---
 .../training/supervised/strategy_wrappers.py |  9 +++
 2 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/avalanche/training/plugins/generative_replay.py b/avalanche/training/plugins/generative_replay.py
index 1d89eb565..88a0edab6 100644
--- a/avalanche/training/plugins/generative_replay.py
+++ b/avalanche/training/plugins/generative_replay.py
@@ -15,8 +15,8 @@
 """
 
 from copy import deepcopy
-from typing import Optional
-from avalanche.core import SupervisedPlugin
+from typing import Optional, Any
+from avalanche.core import SupervisedPlugin, Template
 import torch
 
 
@@ -52,11 +52,14 @@ class GenerativeReplayPlugin(SupervisedPlugin):
     """
 
     def __init__(
-        self,
-        generator_strategy=None,
-        untrained_solver: bool = True,
-        replay_size: Optional[int] = None,
-        increasing_replay_size: bool = False,
+            self,
+            generator_strategy=None,
+            untrained_solver: bool = True,
+            replay_size: Optional[int] = None,
+            increasing_replay_size: bool = False,
+            is_weighted_replay: bool = False,
+            weight_replay_loss_factor: float = 1.0,
+            weight_replay_loss: float = 0.0001,
     ):
         """
         Init.
@@ -71,6 +74,9 @@ def __init__(
         self.model_is_generator = False
         self.replay_size = replay_size
         self.increasing_replay_size = increasing_replay_size
+        self.is_weighted_replay = is_weighted_replay
+        self.weight_replay_loss_factor = weight_replay_loss_factor
+        self.weight_replay_loss = weight_replay_loss
 
     def before_training(self, strategy, *args, **kwargs):
         """Checks whether we are using a user defined external generator
@@ -85,7 +91,7 @@ def before_training(self, strategy, *args, **kwargs):
             self.model_is_generator = True
 
     def before_training_exp(
-        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
+            self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
     ):
         """
         Make deep copies of generator and solver before training new experience.
@@ -101,7 +107,7 @@ def before_training_exp(
         self.old_model.eval()
 
     def after_training_exp(
-        self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
+            self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs
     ):
         """
         Set untrained_solver boolean to False after (the first) experience,
@@ -109,11 +115,59 @@ def after_training_exp(
         """
         self.untrained_solver = False
 
+    def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
+        super().before_backward(strategy, *args, **kwargs)
+        """
+        Generate replay data and calculate the loss on the replay data.
+        Add weighted loss to the total loss if the user has set the weight_replay_loss
+        """
+        if not self.is_weighted_replay:
+            # If we are not using weighted loss, ignore this method
+            return
+
+        if self.untrained_solver:
+            # do not generate on the first experience
+            return
+
+        # determine how many replay data points to generate
+        if self.replay_size:
+            number_replays_to_generate = self.replay_size
+        else:
+            if self.increasing_replay_size:
+                number_replays_to_generate = len(strategy.mbatch[0]) * (
+                    strategy.experience.current_experience
+                )
+            else:
+                number_replays_to_generate = len(strategy.mbatch[0])
+        replay_data = self.old_generator.generate(number_replays_to_generate).to(
+            strategy.device
+        )
+        # get labels for replay data
+        if not self.model_is_generator:
+            with torch.no_grad():
+                replay_output = self.old_model(replay_data).argmax(dim=-1)
+        else:
+            # Mock labels:
+            replay_output = torch.zeros(replay_data.shape[0])
+
+        # make copy of mbatch
+        mbatch = deepcopy(strategy.mbatch)
+        # replace mbatch with replay data, calculate loss and add to strategy.loss
+        strategy.mbatch = [replay_data, replay_output, strategy.mbatch[-1]]
+        strategy.forward()
+        strategy.loss += self.weight_replay_loss * strategy.criterion()
+        self.weight_replay_loss *= self.weight_replay_loss_factor
+        # restore mbatch
+        strategy.mbatch = mbatch
+
     def before_training_iteration(self, strategy, **kwargs):
         """
         Generating and appending replay data to current minibatch before
         each training iteration.
         """
+        if self.is_weighted_replay:
+            # When using weighted loss, do not add replay data to the current minibatch
+            return
         if self.untrained_solver:
             # The solver needs to be trained before labelling generated data and
             # the generator needs to be trained before we can sample.
diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py
index 541d500b2..658c13d05 100644
--- a/avalanche/training/supervised/strategy_wrappers.py
+++ b/avalanche/training/supervised/strategy_wrappers.py
@@ -437,6 +437,9 @@ def __init__(
         generator_strategy: Optional[BaseTemplate] = None,
         replay_size: Optional[int] = None,
         increasing_replay_size: bool = False,
+        is_weighted_replay: bool = False,
+        weight_replay_loss_factor: float = 1.0,
+        weight_replay_loss: float = 0.0001,
         **base_kwargs
     ):
         """
@@ -499,6 +502,9 @@ def __init__(
                     GenerativeReplayPlugin(
                         replay_size=replay_size,
                         increasing_replay_size=increasing_replay_size,
+                        is_weighted_replay=is_weighted_replay,
+                        weight_replay_loss_factor=weight_replay_loss_factor,
+                        weight_replay_loss = weight_replay_loss,
                     )
                 ],
             )
@@ -507,6 +513,9 @@ def __init__(
             generator_strategy=self.generator_strategy,
             replay_size=replay_size,
             increasing_replay_size=increasing_replay_size,
+            is_weighted_replay=is_weighted_replay,
+            weight_replay_loss_factor=weight_replay_loss_factor,
+            weight_replay_loss=weight_replay_loss,
         )
 
         tgp = TrainGeneratorAfterExpPlugin()

From 0855d46006568816323a27f15d74dcdc791b5aed Mon Sep 17 00:00:00 2001
From: gogamid <36050790+gogamid@users.noreply.github.com>
Date: Tue, 13 Feb 2024 01:55:59 +0100
Subject: [PATCH 2/6] add example for generative replay with weighted loss

---
 .../generative_replay_splitMNIST_weighted.py  | 97 +++++++++++++++++++
 1 file changed, 97 insertions(+)
 create mode 100644 examples/generative_replay_splitMNIST_weighted.py

diff --git a/examples/generative_replay_splitMNIST_weighted.py b/examples/generative_replay_splitMNIST_weighted.py
new file mode 100644
index 000000000..a6be2cd7b
--- /dev/null
+++ b/examples/generative_replay_splitMNIST_weighted.py
@@ -0,0 +1,97 @@
+################################################################################
+# Copyright (c) 2021 ContinualAI.                                              #
+# Copyrights licensed under the MIT License.                                   #
+# See the accompanying LICENSE file for terms.                                 #
+#                                                                              #
+# Date: 13-02-2024                                                             #
+# Author(s): Imron Gamidli                                                     #
+#                                                                              #
+################################################################################
+
+"""
+This is a simple example on how to use the GenerativeReplay strategy with weighted replay loss.
+"""
+import datetime
+import argparse
+import torch
+from torch.nn import CrossEntropyLoss
+from torchvision import transforms
+from torchvision.transforms import ToTensor, RandomCrop
+import torch.optim.lr_scheduler
+from avalanche.benchmarks import SplitMNIST
+from avalanche.models import SimpleMLP
+from avalanche.training.supervised import GenerativeReplay
+from avalanche.evaluation.metrics import (
+    forgetting_metrics,
+    accuracy_metrics,
+    loss_metrics,
+)
+from avalanche.logging import InteractiveLogger, TextLogger
+from avalanche.training.plugins import EvaluationPlugin
+
+
+def main(args):
+    # --- CONFIG
+    device = torch.device(
+        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
+    )
+
+    # --- BENCHMARK CREATION
+    benchmark = SplitMNIST(n_experiences=5, seed=1234, fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+    # ---------
+
+    # MODEL CREATION
+    model = SimpleMLP(num_classes=benchmark.n_classes, hidden_size=10)
+
+    # choose some metrics and evaluation method
+    interactive_logger = InteractiveLogger()
+    file_name = 'logs/log_' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + '.log'
+    text_logger = TextLogger(open(file_name, 'a'))
+
+    eval_plugin = EvaluationPlugin(
+        # accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
+        accuracy_metrics(experience=True, stream=True),
+        loss_metrics(minibatch=True),
+        # loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
+        # forgetting_metrics(experience=True),
+        loggers=[interactive_logger, text_logger],
+    )
+
+    # CREATE THE STRATEGY INSTANCE (GenerativeReplay)
+    cl_strategy = GenerativeReplay(
+        model,
+        torch.optim.Adam(model.parameters(), lr=0.001),
+        CrossEntropyLoss(),
+        train_mb_size=100,
+        train_epochs=2,
+        eval_mb_size=100,
+        device=device,
+        evaluator=eval_plugin,
+        is_weighted_replay=True,
+        weight_replay_loss_factor=2.0,
+        weight_replay_loss=0.001,
+    )
+
+    # TRAINING LOOP
+    print("Starting experiment...")
+    results = []
+    for experience in benchmark.train_stream:
+        print("Start of experience ", experience.current_experience)
+        cl_strategy.train(experience)
+        print("Training completed")
+
+        print("Computing accuracy on the whole test set")
+        results.append(cl_strategy.eval(benchmark.test_stream))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cuda",
+        type=int,
+        default=0,
+        help="Select zero-indexed cuda device. -1 to use CPU.",
+    )
-1 to use CPU.", + ) + args = parser.parse_args() + print(args) + main(args) From affe662108626d0b2f2fdc7d3fd74c28fbb6eedb Mon Sep 17 00:00:00 2001 From: gogamid <36050790+gogamid@users.noreply.github.com> Date: Sun, 17 Mar 2024 20:23:30 +0100 Subject: [PATCH 3/6] fix: move docstring befoe super call --- .../training/plugins/generative_replay.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/avalanche/training/plugins/generative_replay.py b/avalanche/training/plugins/generative_replay.py index 88a0edab6..d965766b1 100644 --- a/avalanche/training/plugins/generative_replay.py +++ b/avalanche/training/plugins/generative_replay.py @@ -52,14 +52,14 @@ class GenerativeReplayPlugin(SupervisedPlugin): """ def __init__( - self, - generator_strategy=None, - untrained_solver: bool = True, - replay_size: Optional[int] = None, - increasing_replay_size: bool = False, - is_weighted_replay: bool = False, - weight_replay_loss_factor: float = 1.0, - weight_replay_loss: float = 0.0001, + self, + generator_strategy=None, + untrained_solver: bool = True, + replay_size: Optional[int] = None, + increasing_replay_size: bool = False, + is_weighted_replay: bool = False, + weight_replay_loss_factor: float = 1.0, + weight_replay_loss: float = 0.0001, ): """ Init. @@ -91,7 +91,7 @@ def before_training(self, strategy, *args, **kwargs): self.model_is_generator = True def before_training_exp( - self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs + self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs ): """ Make deep copies of generator and solver before training new experience. @@ -107,7 +107,7 @@ def before_training_exp( self.old_model.eval() def after_training_exp( - self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs + self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs ): """ Set untrained_solver boolean to False after (the first) experience, @@ -116,11 +116,11 @@ def after_training_exp( self.untrained_solver = False def before_backward(self, strategy: Template, *args, **kwargs) -> Any: - super().before_backward(strategy, *args, **kwargs) """ Generate replay data and calculate the loss on the replay data. Add weighted loss to the total loss if the user has set the weight_replay_loss """ + super().before_backward(strategy, *args, **kwargs) if not self.is_weighted_replay: # If we are not using weighted loss, ignore this method return From 0f42d7f2597312d637e062f310391426cdf04a0c Mon Sep 17 00:00:00 2001 From: gogamid <36050790+gogamid@users.noreply.github.com> Date: Sun, 17 Mar 2024 20:41:32 +0100 Subject: [PATCH 4/6] chore: add documentation for the arguments --- avalanche/training/plugins/generative_replay.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/avalanche/training/plugins/generative_replay.py b/avalanche/training/plugins/generative_replay.py index d965766b1..3d6daa8b3 100644 --- a/avalanche/training/plugins/generative_replay.py +++ b/avalanche/training/plugins/generative_replay.py @@ -49,6 +49,14 @@ class GenerativeReplayPlugin(SupervisedPlugin): double the amount of replay data added to each data batch. The effect will be that the older experiences will gradually increase in importance to the final loss. + :param is_weighted_replay: If set to True, the loss function will be weighted + and more importance will be given to the replay data as the number of + experiences increases. 
+    :param weight_replay_loss_factor: If is_weighted_replay is set to True, the user
+        can specify a factor the weight will be multiplied by in each iteration,
+        the default is 1.0
+    :param weight_replay_loss: The user can specify the initial weight of the loss for
+        the replay data. The default is 0.0001
     """
 
     def __init__(

From 7e4fd5ad35e21ab4f8902e00c8c8323769cd683f Mon Sep 17 00:00:00 2001
From: gogamid <36050790+gogamid@users.noreply.github.com>
Date: Sun, 17 Mar 2024 20:49:49 +0100
Subject: [PATCH 5/6] fix: modify one of the existing examples instead of
 creating a new one

---
 examples/generative_replay_splitMNIST.py      | 37 +++++----
 .../generative_replay_splitMNIST_weighted.py  | 97 -------------------
 2 files changed, 20 insertions(+), 114 deletions(-)
 delete mode 100644 examples/generative_replay_splitMNIST_weighted.py

diff --git a/examples/generative_replay_splitMNIST.py b/examples/generative_replay_splitMNIST.py
index 7112128ad..ce60f4c6e 100644
--- a/examples/generative_replay_splitMNIST.py
+++ b/examples/generative_replay_splitMNIST.py
@@ -3,31 +3,27 @@
 # Copyrights licensed under the MIT License.                                   #
 # See the accompanying LICENSE file for terms.                                 #
 #                                                                              #
-# Date: 01-04-2022                                                             #
-# Author(s): Florian Mies                                                      #
-# E-mail: contact@continualai.org                                              #
-# Website: avalanche.continualai.org                                           #
+# Date: 13-02-2024                                                             #
+# Author(s): Imron Gamidli                                                     #
+#                                                                              #
 ################################################################################
 
 """
-This is a simple example on how to use the Replay strategy.
+This is a simple example on how to use the GenerativeReplay strategy with weighted replay loss.
 """
-
+import datetime
 import argparse
 import torch
 from torch.nn import CrossEntropyLoss
-from torchvision import transforms
-from torchvision.transforms import ToTensor, RandomCrop
 import torch.optim.lr_scheduler
 from avalanche.benchmarks import SplitMNIST
 from avalanche.models import SimpleMLP
 from avalanche.training.supervised import GenerativeReplay
 from avalanche.evaluation.metrics import (
-    forgetting_metrics,
     accuracy_metrics,
     loss_metrics,
 )
-from avalanche.logging import InteractiveLogger
+from avalanche.logging import InteractiveLogger, TextLogger
 from avalanche.training.plugins import EvaluationPlugin
 
 
@@ -38,20 +34,23 @@ def main(args):
     )
 
     # --- BENCHMARK CREATION
-    benchmark = SplitMNIST(n_experiences=10, seed=1234)
+    benchmark = SplitMNIST(
+        n_experiences=5, seed=1234, fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    )
     # ---------
 
     # MODEL CREATION
-    model = SimpleMLP(num_classes=benchmark.n_classes)
+    model = SimpleMLP(num_classes=benchmark.n_classes, hidden_size=10)
 
     # choose some metrics and evaluation method
     interactive_logger = InteractiveLogger()
+    file_name = "logs/log_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + ".log"
+    text_logger = TextLogger(open(file_name, "a"))
 
     eval_plugin = EvaluationPlugin(
-        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        forgetting_metrics(experience=True),
-        loggers=[interactive_logger],
+        accuracy_metrics(experience=True, stream=True),
+        loss_metrics(minibatch=True),
+        loggers=[interactive_logger, text_logger],
     )
 
     # CREATE THE STRATEGY INSTANCE (GenerativeReplay)
@@ -60,10 +59,13 @@ def main(args):
         torch.optim.Adam(model.parameters(), lr=0.001),
         CrossEntropyLoss(),
         train_mb_size=100,
-        train_epochs=4,
+        train_epochs=2,
         eval_mb_size=100,
         device=device,
         evaluator=eval_plugin,
+        is_weighted_replay=True,
+        weight_replay_loss_factor=2.0,
+        weight_replay_loss=0.001,
     )
 
     # TRAINING LOOP
@@ -87,4 +89,5 @@ def main(args):
         help="Select zero-indexed cuda device. -1 to use CPU.",
     )
     args = parser.parse_args()
+    print(args)
     main(args)
diff --git a/examples/generative_replay_splitMNIST_weighted.py b/examples/generative_replay_splitMNIST_weighted.py
deleted file mode 100644
index a6be2cd7b..000000000
--- a/examples/generative_replay_splitMNIST_weighted.py
+++ /dev/null
@@ -1,97 +0,0 @@
-################################################################################
-# Copyright (c) 2021 ContinualAI.                                              #
-# Copyrights licensed under the MIT License.                                   #
-# See the accompanying LICENSE file for terms.                                 #
-#                                                                              #
-# Date: 13-02-2024                                                             #
-# Author(s): Imron Gamidli                                                     #
-#                                                                              #
-################################################################################
-
-"""
-This is a simple example on how to use the GenerativeReplay strategy with weighted replay loss.
-"""
-import datetime
-import argparse
-import torch
-from torch.nn import CrossEntropyLoss
-from torchvision import transforms
-from torchvision.transforms import ToTensor, RandomCrop
-import torch.optim.lr_scheduler
-from avalanche.benchmarks import SplitMNIST
-from avalanche.models import SimpleMLP
-from avalanche.training.supervised import GenerativeReplay
-from avalanche.evaluation.metrics import (
-    forgetting_metrics,
-    accuracy_metrics,
-    loss_metrics,
-)
-from avalanche.logging import InteractiveLogger, TextLogger
-from avalanche.training.plugins import EvaluationPlugin
-
-
-def main(args):
-    # --- CONFIG
-    device = torch.device(
-        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
-    )
-
-    # --- BENCHMARK CREATION
-    benchmark = SplitMNIST(n_experiences=5, seed=1234, fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
-    # ---------
-
-    # MODEL CREATION
-    model = SimpleMLP(num_classes=benchmark.n_classes, hidden_size=10)
-
-    # choose some metrics and evaluation method
-    interactive_logger = InteractiveLogger()
-    file_name = 'logs/log_' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + '.log'
-    text_logger = TextLogger(open(file_name, 'a'))
-
-    eval_plugin = EvaluationPlugin(
-        # accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        accuracy_metrics(experience=True, stream=True),
-        loss_metrics(minibatch=True),
-        # loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        # forgetting_metrics(experience=True),
-        loggers=[interactive_logger, text_logger],
-    )
-
-    # CREATE THE STRATEGY INSTANCE (GenerativeReplay)
-    cl_strategy = GenerativeReplay(
-        model,
-        torch.optim.Adam(model.parameters(), lr=0.001),
-        CrossEntropyLoss(),
-        train_mb_size=100,
-        train_epochs=2,
-        eval_mb_size=100,
-        device=device,
-        evaluator=eval_plugin,
-        is_weighted_replay=True,
-        weight_replay_loss_factor=2.0,
-        weight_replay_loss=0.001,
-    )
-
-    # TRAINING LOOP
-    print("Starting experiment...")
-    results = []
-    for experience in benchmark.train_stream:
-        print("Start of experience ", experience.current_experience)
-        cl_strategy.train(experience)
-        print("Training completed")
-
-        print("Computing accuracy on the whole test set")
-        results.append(cl_strategy.eval(benchmark.test_stream))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--cuda",
-        type=int,
-        default=0,
-        help="Select zero-indexed cuda device. -1 to use CPU.",
-1 to use CPU.", - ) - args = parser.parse_args() - print(args) - main(args) From 6082a3cc7a29fc20e0debeeebc6123f2c0768573 Mon Sep 17 00:00:00 2001 From: gogamid <36050790+gogamid@users.noreply.github.com> Date: Sun, 17 Mar 2024 20:50:33 +0100 Subject: [PATCH 6/6] chore: run black formatter --- avalanche/training/supervised/strategy_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index 658c13d05..b13091b64 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -504,7 +504,7 @@ def __init__( increasing_replay_size=increasing_replay_size, is_weighted_replay=is_weighted_replay, weight_replay_loss_factor=weight_replay_loss_factor, - weight_replay_loss = weight_replay_loss, + weight_replay_loss=weight_replay_loss, ) ], )