Implementation of a Masked Autoencoder for representation learning #8152
Conversation
Hi @Lucas-rbnt, thanks for the effort on this follow-up PR. @atbenmurray could you please re-review the content here?
@Lucas-rbnt @atbenmurray I shall do so
Thanks for the PR.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.
In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
As you mentioned here, have you verified whether there is a big difference between the two implementations? Does it have any impact on the final performance? Thanks.
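For context, here is a minimal sketch of the two sampling strategies under discussion (shapes and names are illustrative, not taken from either codebase); both select a uniformly random subset of token indices:

```python
import torch

num_tokens, mask_ratio = 16, 0.75
num_keep = int(num_tokens * (1 - mask_ratio))

# Official-style: sort random noise and keep the first part of the shuffle.
noise = torch.rand(num_tokens)
ids_shuffle = torch.argsort(noise)        # a random permutation of token indices
ids_restore = torch.argsort(ids_shuffle)  # inverse permutation, used to unshuffle later
ids_keep = ids_shuffle[:num_keep]         # indices of the visible tokens

# This PR (as described above): draw mask indices directly, then use boolean indexing.
mask_idx = torch.multinomial(torch.ones(num_tokens), num_tokens - num_keep)
mask = torch.zeros(num_tokens, dtype=torch.bool)
mask[mask_idx] = True                     # True = masked, False = visible
visible_idx = (~mask).nonzero(as_tuple=True)[0]
```

Since an argsort of uniform noise yields a uniform random permutation, and multinomial with equal weights draws without replacement, the two masked sets should in principle be statistically equivalent.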
I think this is fine now, though the comments should be addressed and the conflict resolved; then we can trigger the blossom tests. Thanks!
/build
I'm good with this and look forward to an example notebook in Tutorials demonstrating its use!
This follows a previous PR (#7598).
In the previous PR, the official implementation was under a non-compatible license, so this is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer encoder and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.

Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.
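To make that concrete, here is a rough sketch of the boolean-indexing flow, simplified to a single mask shared across the batch; names and shapes are illustrative, not the exact code in this PR:

```python
import torch

batch, num_tokens, dim = 2, 16, 32
x = torch.randn(batch, num_tokens, dim)   # patch embeddings
mask_token = torch.zeros(1, 1, dim)       # a learnable parameter in the real model

mask_idx = torch.multinomial(torch.ones(num_tokens), int(num_tokens * 0.75))
mask = torch.zeros(num_tokens, dtype=torch.bool)
mask[mask_idx] = True                     # True = masked

# The encoder only sees the visible patches.
visible = x[:, ~mask, :]                  # (batch, num_visible, dim)

# The decoder input restores the original order: encoded tokens where the
# patch was visible, mask tokens everywhere else.
decoder_in = mask_token.expand(batch, num_tokens, dim).clone()
decoder_in[:, ~mask, :] = visible
```

Boolean indexing preserves the original token order, which is why no inverse-permutation (`ids_restore`) bookkeeping is needed here.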
Description
Implementation of the Masked Autoencoder as described in the paper Masked Autoencoders Are Scalable Vision Learners by He et al.
Its effectiveness for medical tasks has already been demonstrated in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation.
The PR contains the architecture and associated unit tests.
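For anyone who wants to experiment, a hypothetical usage sketch follows; the class name `MaskedAutoEncoderViT` and every constructor argument here are assumptions for illustration and may not match the merged API:

```python
import torch
from monai.networks.nets import MaskedAutoEncoderViT  # name assumed, see lead-in

# Hypothetical configuration for single-channel 96^3 volumes with 16^3 patches.
model = MaskedAutoEncoderViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    masking_ratio=0.75,  # assumed parameter name
)
volume = torch.randn(1, 1, 96, 96, 96)
prediction, mask = model(volume)  # per-token reconstruction and its mask
```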
Note: The output includes the prediction, a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches, but I'm not sure it's the “best” output format; what do you think?
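As one way of consuming that output format, the reconstruction loss can be restricted to masked patches along these lines (a sketch assuming `mask` is 1 for masked tokens and the target is the patchified input):

```python
import torch

bs, n_tokens, d = 2, 216, 4096                     # illustrative shapes
prediction = torch.randn(bs, n_tokens, d)          # model output
target = torch.randn(bs, n_tokens, d)              # patchified input in practice
mask = (torch.rand(bs, n_tokens) < 0.75).float()   # 1 = masked, 0 = visible

per_token = ((prediction - target) ** 2).mean(dim=-1)  # per-token MSE
loss = (per_token * mask).sum() / mask.sum()           # average over masked tokens only
```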
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested by running the `make html` command in the `docs/` folder.