-
-
Notifications
You must be signed in to change notification settings - Fork 297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generative Replay with weighted loss for replayed data #1596
Merged
+93
−19
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
e2dfaae
feat: add support for weighted loss in generative replay
gogamid 0855d46
add example for generative replay with weighted loss
gogamid 9c9955b
Merge branch 'master' of https://github.com/ContinualAI/avalanche
gogamid affe662
fix: move docstring befoe super call
gogamid 0f42d7f
chore: add documentation for the arguments
gogamid 7e4fd5a
fix: modify one of the existing example instead of creating a new one
gogamid 6082a3c
chore: run black formatter
gogamid File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,8 +15,8 @@ | |
""" | ||
|
||
from copy import deepcopy | ||
from typing import Optional | ||
from avalanche.core import SupervisedPlugin | ||
from typing import Optional, Any | ||
from avalanche.core import SupervisedPlugin, Template | ||
import torch | ||
|
||
|
||
|
@@ -49,6 +49,14 @@ class GenerativeReplayPlugin(SupervisedPlugin): | |
double the amount of replay data added to each data batch. The effect | ||
will be that the older experiences will gradually increase in importance | ||
to the final loss. | ||
:param is_weighted_replay: If set to True, the loss function will be weighted | ||
and more importance will be given to the replay data as the number of | ||
experiences increases. | ||
:param weight_replay_loss_factor: If is_weighted_replay is set to True, the user | ||
can specify a factor the weight will be multiplied by in each iteration, | ||
the default is 1.0 | ||
:param weight_replay_loss: The user can specify the initial weight of the loss for | ||
the replay data. The default is 0.0001 | ||
""" | ||
|
||
def __init__( | ||
|
@@ -57,6 +65,9 @@ def __init__( | |
untrained_solver: bool = True, | ||
replay_size: Optional[int] = None, | ||
increasing_replay_size: bool = False, | ||
is_weighted_replay: bool = False, | ||
weight_replay_loss_factor: float = 1.0, | ||
weight_replay_loss: float = 0.0001, | ||
): | ||
""" | ||
Init. | ||
|
@@ -71,6 +82,9 @@ def __init__( | |
self.model_is_generator = False | ||
self.replay_size = replay_size | ||
self.increasing_replay_size = increasing_replay_size | ||
self.is_weighted_replay = is_weighted_replay | ||
self.weight_replay_loss_factor = weight_replay_loss_factor | ||
self.weight_replay_loss = weight_replay_loss | ||
|
||
def before_training(self, strategy, *args, **kwargs): | ||
"""Checks whether we are using a user defined external generator | ||
|
@@ -109,11 +123,59 @@ def after_training_exp( | |
""" | ||
self.untrained_solver = False | ||
|
||
def before_backward(self, strategy: Template, *args, **kwargs) -> Any: | ||
""" | ||
Generate replay data and calculate the loss on the replay data. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring must be before super() call |
||
Add weighted loss to the total loss if the user has set the weight_replay_loss | ||
""" | ||
super().before_backward(strategy, *args, **kwargs) | ||
if not self.is_weighted_replay: | ||
# If we are not using weighted loss, ignore this method | ||
return | ||
|
||
if self.untrained_solver: | ||
# do not generate on the first experience | ||
return | ||
|
||
# determine how many replay data points to generate | ||
if self.replay_size: | ||
number_replays_to_generate = self.replay_size | ||
else: | ||
if self.increasing_replay_size: | ||
number_replays_to_generate = len(strategy.mbatch[0]) * ( | ||
strategy.experience.current_experience | ||
) | ||
else: | ||
number_replays_to_generate = len(strategy.mbatch[0]) | ||
replay_data = self.old_generator.generate(number_replays_to_generate).to( | ||
strategy.device | ||
) | ||
# get labels for replay data | ||
if not self.model_is_generator: | ||
with torch.no_grad(): | ||
replay_output = self.old_model(replay_data).argmax(dim=-1) | ||
else: | ||
# Mock labels: | ||
replay_output = torch.zeros(replay_data.shape[0]) | ||
|
||
# make copy of mbatch | ||
mbatch = deepcopy(strategy.mbatch) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you need a deepcopy here? |
||
# replace mbatch with replay data, calculate loss and add to strategy.loss | ||
strategy.mbatch = [replay_data, replay_output, strategy.mbatch[-1]] | ||
strategy.forward() | ||
strategy.loss += self.weight_replay_loss * strategy.criterion() | ||
self.weight_replay_loss *= self.weight_replay_loss_factor | ||
# restore mbatch | ||
strategy.mbatch = mbatch | ||
|
||
def before_training_iteration(self, strategy, **kwargs): | ||
""" | ||
Generating and appending replay data to current minibatch before | ||
each training iteration. | ||
""" | ||
if self.is_weighted_replay: | ||
# When using weighted loss, do not add replay data to the current minibatch | ||
return | ||
if self.untrained_solver: | ||
# The solver needs to be trained before labelling generated data and | ||
# the generator needs to be trained before we can sample. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add the documentation for the arguments?