Example implementation of the SimMIM architecture (SimMIM: A Simple Framework for Masked Image Modeling). SimMIM is closely related to Masked Autoencoders Are Scalable Vision Learners (MAE, 2021), with two main differences: the ViT encoder receives both masked and non-masked patches as input (MAE encodes only the visible patches), and the decoder is a single linear layer trained with an L1 reconstruction loss instead of MAE's L2 loss.


SimMIM: A Simple Framework for Masked Image Modeling, 2021

This example can be run from the command line with:

python lightly/examples/pytorch/

import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset
from lightly.data.multi_view_collate import MultiViewCollate
from lightly.models import utils
from lightly.models.modules import masked_autoencoder
from lightly.transforms.mae_transform import MAETransform  # Same transform as MAE

class SimMIM(nn.Module):
    def __init__(self, vit):
        super().__init__()

        decoder_dim = vit.hidden_dim
        self.mask_ratio = 0.75
        self.patch_size = vit.patch_size
        self.sequence_length = vit.seq_length
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        # same backbone as MAE
        self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit)

        # the decoder is a simple linear layer
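        # (each encoded token is projected to a flattened RGB patch,
        # i.e. patch_size**2 * 3 pixel values)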
        self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3)

    def forward_encoder(self, images, batch_size, idx_mask):
        # pass all tokens to the encoder, both masked and non-masked ones
        tokens = self.backbone.images_to_tokens(images, prepend_class_token=True)
        tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token)
        return self.backbone.encoder(tokens_masked)

    def forward_decoder(self, x_encoded):
        return self.decoder(x_encoded)

    def forward(self, images):
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )

        # Encoding: encode all tokens, then keep only the encodings at the
        # masked positions; SimMIM computes the reconstruction loss on the
        # masked patches only
        x_encoded = self.forward_encoder(images, batch_size, idx_mask)
        x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)

        # Decoding...
        x_out = self.forward_decoder(x_encoded_masked)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)

        # patchify returns patches without a class token, while idx_mask indexes
        # the full token sequence (class token at position 0), so shift by one
        target = utils.get_at_index(patches, idx_mask - 1)

        return x_out, target

vit = torchvision.models.vit_b_32(weights=None)
model = SimMIM(vit)
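
# A quick shape sanity check (illustrative; assumes the default 224x224 input
# resolution of vit_b_32, where patch_size=32 yields 7 * 7 = 49 patch tokens
# plus one class token):
#   predictions, targets = model(torch.randn(2, 3, 224, 224))
#   # both have shape (batch_size, num_masked_tokens, patch_size**2 * 3)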

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# we ignore object detection annotations by setting target_transform to return 0
pascal_voc = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc", download=False, target_transform=lambda t: 0
transform = MAETransform()
dataset = LightlyDataset.from_torch_dataset(pascal_voc, transform=transform)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = MultiViewCollate()

# batch size and worker count are illustrative; adjust them to your hardware
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

# L1 reconstruction loss, as suggested in the paper
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
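
# Optionally, masked image modeling recipes commonly pair AdamW with a cosine
# learning rate schedule; a minimal sketch (not part of the original example):
#   scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
#   # then call scheduler.step() once per epoch in the loop below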

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for images, _, _ in dataloader:
        images = images[0].to(device)  # images is a list containing only one view
        predictions, targets = model(images)

        loss = criterion(predictions, targets)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")