
Implementation of a Masked Autoencoder for representation learning #8152

Merged · 9 commits into Project-MONAI:dev · Nov 27, 2024

Conversation

Lucas-rbnt
Contributor

This follows a previous PR (#7598).

In the previous PR, the official implementation was under an incompatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward: a transformer encoder and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model.

In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.

In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
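
To make the difference concrete, here is a minimal sketch of the two masking strategies; the tensor sizes, keep ratio, and variable names are illustrative assumptions, not code from this PR:

```python
import torch

B, N, D = 2, 16, 32       # batch, tokens, embed dim (illustrative values)
x = torch.randn(B, N, D)  # patch embeddings
n_keep = N // 4           # keep 25% of tokens (75% masked), as in the MAE paper

# Official-style masking: sort random noise, then sort again to invert.
noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation, used by the decoder
ids_keep = ids_shuffle[:, :n_keep]
x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

# This PR's approach: draw mask indices with torch.multinomial,
# then use boolean indexing to select the visible patches.
weights = torch.ones(B, N)
ids_mask = torch.multinomial(weights, N - n_keep)  # indices to mask, without replacement
mask = torch.zeros(B, N, dtype=torch.bool)
mask.scatter_(1, ids_mask, True)                   # True where a patch is masked
x_visible = x[~mask].reshape(B, n_keep, D)         # encoder sees only visible patches
```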

Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.

Description

Implementation of the Masked Autoencoder as described in the paper Masked Autoencoders Are Scalable Vision Learners by He et al.

Its effectiveness has already been demonstrated in the literature for medical tasks in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation.
The PR contains the architecture and associated unit tests.
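
For orientation, a hypothetical usage sketch follows; the class name is assumed from the PR's file masked_autoencoder_vit.py, and the constructor arguments follow MONAI's ViT conventions, so the exact signature may differ:

```python
import torch
# Class name assumed from the PR's file name; argument names are illustrative.
from monai.networks.nets import MaskedAutoEncoderViT

net = MaskedAutoEncoderViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
)
x = torch.randn(2, 1, 96, 96, 96)
pred, mask = net(x)  # reconstruction (BS, N_tokens, D) and mask (BS, N_tokens)
```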

Note: The output includes the prediction, a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask of size ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches. I'm not sure this is the “best” output format, though; what do you think?
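
For context, a hedged sketch of how the returned mask could restrict the reconstruction loss to masked patches; the target construction and shapes are illustrative assumptions, not code from this PR:

```python
import torch

# Hypothetical shapes matching the description above:
# pred: (BS, N_tokens, D) reconstruction; mask: (BS, N_tokens), 1 on masked patches.
pred = torch.randn(4, 16, 32)
target = torch.randn(4, 16, 32)            # e.g. the original (normalized) patch pixels
mask = (torch.rand(4, 16) < 0.75).float()  # ~75% masking ratio (illustrative)

# Mean squared error per patch, averaged over masked patches only.
loss_per_patch = ((pred - target) ** 2).mean(dim=-1)  # (BS, N_tokens)
loss = (loss_per_patch * mask).sum() / mask.sum().clamp(min=1)
```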

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod
Member

Hi @Lucas-rbnt thanks for the effort on this followup PR. @atbenmurray could you please re-review the content here?

@atbenmurray
Contributor

@Lucas-rbnt @atbenmurray I shall do so

Contributor

@KumoLiu KumoLiu left a comment


Thanks for the PR.

In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.
In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.

As you mentioned here, have you verified whether there is a significant difference between the two implementations? Does it have any impact on the final performance? Thanks.

Review thread on monai/networks/nets/masked_autoencoder_vit.py (outdated, resolved).
@KumoLiu KumoLiu requested a review from ericspod November 15, 2024 16:32
@ericspod
Member

I think this is fine now, though the comments should be looked at and the conflict resolved; then we can trigger the blossom tests. Thanks!

@KumoLiu
Contributor

KumoLiu commented Nov 27, 2024

/build

Member

@ericspod ericspod left a comment


I'm good with this and look forward to an example notebook in Tutorials demonstrating its use!

@KumoLiu KumoLiu merged commit 20372f0 into Project-MONAI:dev Nov 27, 2024
28 checks passed