MAE

Example implementation of the Masked Autoencoder (MAE) method. MAE uses a Vision Transformer (ViT) backbone and learns image representations by predicting the pixel values of masked patches. As an autoencoder, it consists of an encoder that maps the visible patches of a masked image to latent representations and a decoder that reconstructs the input image from these representations. Because masked patches are dropped before encoding, the transformer encoder processes a much shorter sequence, which makes MAE more computationally efficient than other transformer-based self-supervised learning methods. Reconstructing the masked patches forces the model to learn meaningful representations of the data.

Key Components

  • Data Augmentations: Unlike contrastive and most self-distillation methods, MAE minimizes reliance on handcrafted data augmentations. The only augmentations used are random resized cropping and horizontal flipping.

  • Masking: MAE applies masking to 75% of the input patches, meaning only 25% of the image tokens are fed into the transformer encoder.

  • Backbone: MAE employs a standard ViT to encode the masked images.

  • Decoder: The decoder processes visible tokens alongside shared, learnable mask tokens. It reconstructs the original input image by predicting the pixel values of the masked patches.

  • Reconstruction Loss: A Mean Squared Error (MSE) loss is applied between the original and reconstructed pixel values of the masked patches.
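The masking and loss components above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the helper names `random_patch_mask` and `gather_patches` are hypothetical, the class token is ignored, and a tensor of zeros stands in for the decoder output.

```python
import torch


def random_patch_mask(batch_size, num_patches, mask_ratio=0.75, generator=None):
    """Return (idx_keep, idx_mask): per-sample indices of visible and masked patches."""
    num_keep = int(num_patches * (1 - mask_ratio))
    # Sorting random noise yields an independent random permutation per sample.
    noise = torch.rand(batch_size, num_patches, generator=generator)
    order = torch.argsort(noise, dim=1)
    return order[:, :num_keep], order[:, num_keep:]


def gather_patches(patches, idx):
    """Select patches of shape (B, N, D) at per-sample indices idx of shape (B, K)."""
    index = idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1])
    return torch.gather(patches, 1, index)


# Toy input: 2 images, 49 patches each (a 7x7 grid), 32*32*3 pixel values per patch.
patches = torch.rand(2, 49, 32 * 32 * 3)
idx_keep, idx_mask = random_patch_mask(batch_size=2, num_patches=49)

visible = gather_patches(patches, idx_keep)  # what the encoder would see
target = gather_patches(patches, idx_mask)  # what the decoder must reconstruct

# Stand-in for the decoder output; a real model would predict this from `visible`.
prediction = torch.zeros_like(target)
loss = torch.nn.functional.mse_loss(prediction, target)
print(visible.shape, target.shape, loss.item())
```

Note that the loss is computed only on the masked patches; the visible patches never enter the reconstruction target.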

Good to Know

  • Backbone Networks: The masking process used by MAE is inherently incompatible with convolution-based architectures, because dropping patches breaks the dense 2D grid that standard convolutions operate on.

  • Computational Efficiency: The masking mechanism allows the encoder to process only a subset of the image tokens, significantly reducing computational overhead.

  • Scalability: MAE scales well with respect to both model and data size.

  • Versatility: The minimal reliance on handcrafted data augmentations makes MAE adaptable to diverse data domains. For example, it has been applied to medical imaging.

  • Shallow Evaluations: Despite their strong performance in the fine-tuning regime, models trained with MAE tend to underperform in shallow evaluations, such as k-NN or linear evaluation with a frozen backbone.
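To make the efficiency point concrete, the following back-of-the-envelope sketch counts how many patches remain visible at a 75% mask ratio and estimates the relative self-attention cost. It is a rough illustration under simplifying assumptions: the function name `visible_patches` is hypothetical, the class token is ignored, only the quadratic attention term is considered, and the decoder (which processes the full sequence) is left out.

```python
def visible_patches(image_size, patch_size, mask_ratio):
    """Return (total patches, patches left visible after masking)."""
    per_side = image_size // patch_size
    total = per_side * per_side
    return total, int(total * (1 - mask_ratio))


for patch_size in (16, 32):
    total, visible = visible_patches(224, patch_size, 0.75)
    # Self-attention compares every token with every other token, so its cost
    # grows roughly with the square of the sequence length.
    attention_fraction = (visible / total) ** 2
    print(
        f"patch size {patch_size}: {visible}/{total} patches visible, "
        f"attention cost fraction {attention_fraction:.4f}"
    )
```

For a 224x224 image with 16x16 patches, the encoder sees 49 of 196 patches, so the pairwise attention term shrinks to about one sixteenth of the unmasked cost.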

Reference:

Masked Autoencoders Are Scalable Vision Learners, 2021

Note

MAE requires TIMM to be installed:

pip install "lightly[timm]"

This example can be run from the command line with:

python lightly/examples/pytorch/mae.py
# This example requires the following dependencies to be installed:
# pip install lightly[timm]

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from timm.models.vision_transformer import vit_base_patch32_224
from torch import nn

from lightly.models import utils
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from lightly.transforms import MAETransform


class MAE(nn.Module):
    def __init__(self, vit):
        super().__init__()

        decoder_dim = 512
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_embed.patch_size[0]

        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=decoder_dim,
            decoder_depth=1,
            decoder_num_heads=16,
            mlp_ratio=4.0,
            proj_drop_rate=0.0,
            attn_drop_rate=0.0,
        )

    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(
            self.decoder.mask_token, (batch_size, self.sequence_length)
        )
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred

    def forward(self, images):
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(
            x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask
        )

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for missing class token
        target = utils.get_at_index(patches, idx_mask - 1)
        return x_pred, target


vit = vit_base_patch32_224()
model = MAE(vit)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = MAETransform()


# We ignore object detection annotations by setting target_transform to return 0.
def target_transform(t):
    return 0


dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# from lightly.data import LightlyDataset
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch in dataloader:
        views = batch[0]
        images = views[0].to(device)  # views contains only a single view
        predictions, targets = model(images)
        loss = criterion(predictions, targets)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
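The `idx_mask - 1` adjustment in `forward` is easy to miss: token indices count the class token at position 0, while the patch tensor has no class token. The sketch below illustrates the offset in plain PyTorch. It makes several assumptions: this `patchify` is a hypothetical stand-in whose pixel ordering within a patch is an arbitrary choice and may differ from `utils.patchify`, and the masked indices are hand-picked for illustration.

```python
import torch


def patchify(images, patch_size):
    """Split (B, C, H, W) images into (B, N, patch_size * patch_size * C) flat patches."""
    b, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, gh, gw, p, p, C)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


images = torch.rand(2, 3, 224, 224)
patches = patchify(images, patch_size=32)  # (2, 49, 3072), no class token

# Token sequence: index 0 is the class token, indices 1..49 are the patches.
idx_mask = torch.tensor([[1, 5, 49], [2, 3, 4]])  # hypothetical masked token indices
patch_idx = idx_mask - 1  # shift token indices into patch coordinates

index = patch_idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1])
target = torch.gather(patches, 1, index)
print(patches.shape, target.shape)
```

Token index 1 maps to patch 0 (the top-left patch) and token index 49 maps to patch 48 (the bottom-right patch), which is why the class token must never appear among the masked indices when this shift is applied.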