diff --git a/avalanche/training/plugins/generative_replay.py b/avalanche/training/plugins/generative_replay.py
index 1d89eb565..3d6daa8b3 100644
--- a/avalanche/training/plugins/generative_replay.py
+++ b/avalanche/training/plugins/generative_replay.py
@@ -15,8 +15,8 @@
 """
 from copy import deepcopy
-from typing import Optional
-from avalanche.core import SupervisedPlugin
+from typing import Optional, Any
+from avalanche.core import SupervisedPlugin, Template
 
 import torch
 
 
@@ -49,6 +49,14 @@ class GenerativeReplayPlugin(SupervisedPlugin):
         double the amount of replay data added to each data batch. The effect
         will be that the older experiences will gradually increase in
         importance to the final loss.
+    :param is_weighted_replay: If set to True, the loss on the replay data is
+        computed separately and added to the total loss with a weight, giving
+        the replay data increasing importance as more experiences are seen.
+    :param weight_replay_loss_factor: Factor by which the replay loss weight is
+        multiplied after each training iteration when is_weighted_replay is set
+        to True. Defaults to 1.0.
+    :param weight_replay_loss: Initial weight of the loss on the replay data.
+        Defaults to 0.0001.
     """
 
     def __init__(
@@ -57,6 +65,9 @@ def __init__(
         untrained_solver: bool = True,
         replay_size: Optional[int] = None,
         increasing_replay_size: bool = False,
+        is_weighted_replay: bool = False,
+        weight_replay_loss_factor: float = 1.0,
+        weight_replay_loss: float = 0.0001,
     ):
         """
         Init.
@@ -71,6 +82,9 @@ def __init__(
         self.model_is_generator = False
         self.replay_size = replay_size
         self.increasing_replay_size = increasing_replay_size
+        self.is_weighted_replay = is_weighted_replay
+        self.weight_replay_loss_factor = weight_replay_loss_factor
+        self.weight_replay_loss = weight_replay_loss
 
     def before_training(self, strategy, *args, **kwargs):
         """Checks whether we are using a user defined external generator
@@ -109,11 +123,59 @@ def after_training_exp(
         """
         self.untrained_solver = False
 
+    def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
+        """
+        Generate replay data and calculate the loss on the replay data.
+        Add the weighted replay loss to the total loss if is_weighted_replay is set.
+        """
+        super().before_backward(strategy, *args, **kwargs)
+        if not self.is_weighted_replay:
+            # If we are not using weighted loss, ignore this method
+            return
+
+        if self.untrained_solver:
+            # do not generate on the first experience
+            return
+
+        # determine how many replay data points to generate
+        if self.replay_size:
+            number_replays_to_generate = self.replay_size
+        else:
+            if self.increasing_replay_size:
+                number_replays_to_generate = len(strategy.mbatch[0]) * (
+                    strategy.experience.current_experience
+                )
+            else:
+                number_replays_to_generate = len(strategy.mbatch[0])
+        replay_data = self.old_generator.generate(number_replays_to_generate).to(
+            strategy.device
+        )
+        # get labels for replay data
+        if not self.model_is_generator:
+            with torch.no_grad():
+                replay_output = self.old_model(replay_data).argmax(dim=-1)
+        else:
+            # Mock labels:
+            replay_output = torch.zeros(replay_data.shape[0]).to(strategy.device)
+
+        # make copy of mbatch
+        mbatch = deepcopy(strategy.mbatch)
+        # replace mbatch with replay data, calculate loss and add to strategy.loss
+        strategy.mbatch = [replay_data, replay_output, strategy.mbatch[-1]]
+        strategy.forward()
+        strategy.loss += self.weight_replay_loss * strategy.criterion()
+        self.weight_replay_loss *= self.weight_replay_loss_factor
+        # restore mbatch
+        strategy.mbatch = mbatch
+
     def before_training_iteration(self, strategy, **kwargs):
         """
         Generating and appending replay data to current minibatch before
         each training iteration.
         """
+        if self.is_weighted_replay:
+            # weighted replay loss: replay data is handled in before_backward instead
+            return
         if self.untrained_solver:
             # The solver needs to be trained before labelling generated data and
             # the generator needs to be trained before we can sample.
diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py
index 541d500b2..b13091b64 100644
--- a/avalanche/training/supervised/strategy_wrappers.py
+++ b/avalanche/training/supervised/strategy_wrappers.py
@@ -437,6 +437,9 @@ def __init__(
         generator_strategy: Optional[BaseTemplate] = None,
         replay_size: Optional[int] = None,
         increasing_replay_size: bool = False,
+        is_weighted_replay: bool = False,
+        weight_replay_loss_factor: float = 1.0,
+        weight_replay_loss: float = 0.0001,
         **base_kwargs
     ):
         """
@@ -499,6 +502,9 @@ def __init__(
                     GenerativeReplayPlugin(
                         replay_size=replay_size,
                         increasing_replay_size=increasing_replay_size,
+                        is_weighted_replay=is_weighted_replay,
+                        weight_replay_loss_factor=weight_replay_loss_factor,
+                        weight_replay_loss=weight_replay_loss,
                     )
                 ],
             )
@@ -507,6 +513,9 @@ def __init__(
             generator_strategy=self.generator_strategy,
             replay_size=replay_size,
             increasing_replay_size=increasing_replay_size,
+            is_weighted_replay=is_weighted_replay,
+            weight_replay_loss_factor=weight_replay_loss_factor,
+            weight_replay_loss=weight_replay_loss,
         )
 
         tgp = TrainGeneratorAfterExpPlugin()
diff --git a/examples/generative_replay_splitMNIST.py b/examples/generative_replay_splitMNIST.py
index 7112128ad..ce60f4c6e 100644
--- a/examples/generative_replay_splitMNIST.py
+++ b/examples/generative_replay_splitMNIST.py
@@ -3,31 +3,27 @@
 # Copyrights licensed under the MIT License.                                   #
 # See the accompanying LICENSE file for terms.                                 #
 #                                                                              #
-# Date: 01-04-2022                                                             #
-# Author(s): Florian Mies                                                      #
-# E-mail: contact@continualai.org                                              #
-# Website: avalanche.continualai.org                                           #
+# Date: 13-02-2024                                                             #
+# Author(s): Imron Gamidli                                                     #
+#                                                                              #
 ################################################################################
 
 """
-This is a simple example on how to use the Replay strategy.
+This is a simple example of how to use the GenerativeReplay strategy with weighted replay loss.
 """
-
+import datetime
 import argparse
 import torch
 from torch.nn import CrossEntropyLoss
-from torchvision import transforms
-from torchvision.transforms import ToTensor, RandomCrop
 import torch.optim.lr_scheduler
 from avalanche.benchmarks import SplitMNIST
 from avalanche.models import SimpleMLP
 from avalanche.training.supervised import GenerativeReplay
 from avalanche.evaluation.metrics import (
-    forgetting_metrics,
     accuracy_metrics,
     loss_metrics,
 )
-from avalanche.logging import InteractiveLogger
+from avalanche.logging import InteractiveLogger, TextLogger
 from avalanche.training.plugins import EvaluationPlugin
 
 
@@ -38,20 +34,23 @@ def main(args):
     )
 
     # --- BENCHMARK CREATION
-    benchmark = SplitMNIST(n_experiences=10, seed=1234)
+    benchmark = SplitMNIST(
+        n_experiences=5, seed=1234, fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    )
     # ---------
 
     # MODEL CREATION
-    model = SimpleMLP(num_classes=benchmark.n_classes)
+    model = SimpleMLP(num_classes=benchmark.n_classes, hidden_size=10)
 
     # choose some metrics and evaluation method
     interactive_logger = InteractiveLogger()
+    file_name = "logs/log_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + ".log"
+    text_logger = TextLogger(open(file_name, "a"))
     eval_plugin = EvaluationPlugin(
-        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
-        forgetting_metrics(experience=True),
-        loggers=[interactive_logger],
+        accuracy_metrics(experience=True, stream=True),
+        loss_metrics(minibatch=True),
+        loggers=[interactive_logger, text_logger],
     )
 
     # CREATE THE STRATEGY INSTANCE (GenerativeReplay)
@@ -60,10 +59,13 @@ def main(args):
         torch.optim.Adam(model.parameters(), lr=0.001),
         CrossEntropyLoss(),
         train_mb_size=100,
-        train_epochs=4,
+        train_epochs=2,
         eval_mb_size=100,
         device=device,
         evaluator=eval_plugin,
+        is_weighted_replay=True,
+        weight_replay_loss_factor=2.0,
+        weight_replay_loss=0.001,
     )
 
     # TRAINING LOOP
@@ -87,4 +89,5 @@ def main(args):
         help="Select zero-indexed cuda device. -1 to use CPU.",
     )
     args = parser.parse_args()
+    print(args)
     main(args)
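
A note on the new parameters: the replay-loss weight is updated once per training iteration inside before_backward (weight_replay_loss *= weight_replay_loss_factor), so any factor greater than 1.0 makes the weight grow geometrically over the whole training run rather than once per experience. A minimal standalone sketch (plain Python, reusing the values from the example above) of how the effective weight evolves:

    # Sketch of the replay-weight schedule implemented in before_backward.
    # Values match the example above: initial weight 0.001, factor 2.0.
    weight_replay_loss = 0.001
    weight_replay_loss_factor = 2.0

    for iteration in range(12):
        # total loss = current-batch loss + weight_replay_loss * replay loss
        print(f"iteration {iteration:2d}: replay weight = {weight_replay_loss:.4f}")
        weight_replay_loss *= weight_replay_loss_factor

With these settings the replay weight exceeds 1.0 after roughly ten iterations; keeping weight_replay_loss_factor at its default of 1.0 holds the weight constant at weight_replay_loss.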