SimMIM
Example implementation of SimMIM: A Simple Framework for Masked Image Modeling. SimMIM is closely related to MAE (Masked Autoencoders Are Scalable Vision Learners, 2021), but differs in a few ways: its ViT encoder receives both masked and non-masked patches as input, the decoder is just a single linear layer, and it uses an L1 instead of an L2 reconstruction loss.
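The decoder and loss are simple enough to sketch in a few lines. The snippet below is only an illustration of these two choices with hypothetical tensor shapes; the full example that follows shows how the masked tokens and reconstruction targets are actually produced.
import torch
from torch import nn

embed_dim, patch_size, num_masked = 768, 32, 37  # hypothetical ViT-B/32 values
# the SimMIM decoder is a single linear projection from token embeddings to raw pixels
decoder = nn.Linear(embed_dim, patch_size**2 * 3)
criterion = nn.L1Loss()  # L1 instead of MAE's L2 reconstruction loss

# encoder outputs at the masked positions and the ground-truth pixels of those patches
masked_tokens = torch.randn(8, num_masked, embed_dim)
target_pixels = torch.rand(8, num_masked, patch_size**2 * 3)
loss = criterion(decoder(masked_tokens), target_pixels)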
The plain PyTorch version of this example can be run from the command line with:
python lightly/examples/pytorch/simmim.py
# This example requires the following dependencies to be installed:
# pip install lightly
import torch
import torchvision
from torch import nn
from lightly.models import utils
from lightly.models.modules.masked_vision_transformer_torchvision import (
MaskedVisionTransformerTorchvision,
)
from lightly.transforms.mae_transform import MAETransform # Same transform as MAE
class SimMIM(nn.Module):
def __init__(self, vit):
super().__init__()
decoder_dim = vit.hidden_dim
self.mask_ratio = 0.75
self.patch_size = vit.patch_size
self.sequence_length = vit.seq_length
self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
# the decoder is a simple linear layer
self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)
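        # each encoded token is projected to the patch_size * patch_size * 3 raw
        # pixel values of its image patch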
def forward_encoder(self, images, batch_size, idx_mask):
        # pass all the tokens to the encoder, both masked and non-masked ones
return self.backbone.encode(images=images, idx_mask=idx_mask)
def forward_decoder(self, x_encoded):
return self.decoder(x_encoded)
def forward(self, images):
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
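        # idx_keep indexes the visible tokens and idx_mask the masked ones; the
        # class token at position 0 is never masked by default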
# Encoding...
x_encoded = self.forward_encoder(images, batch_size, idx_mask)
x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)
# Decoding...
x_out = self.forward_decoder(x_encoded_masked)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
target = utils.get_at_index(patches, idx_mask - 1)
return x_out, target
vit = torchvision.models.vit_b_32(pretrained=False)
model = SimMIM(vit)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = MAETransform()
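# MAETransform returns a single augmented view per image (random resized crop and
# horizontal flip)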
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=8,
shuffle=True,
drop_last=True,
num_workers=8,
)
# L1 loss, as suggested in the paper
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)
print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch in dataloader:
views = batch[0]
images = views[0].to(device) # views contains only a single view
predictions, targets = model(images)
loss = criterion(predictions, targets)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example implemented with PyTorch Lightning can be run from the command line with:
python lightly/examples/pytorch_lightning/simmim.py
# This example requires the following dependencies to be installed:
# pip install lightly
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.transforms.mae_transform import MAETransform # Same transform as MAE
class SimMIM(pl.LightningModule):
def __init__(self):
super().__init__()
vit = torchvision.models.vit_b_32(pretrained=False)
decoder_dim = vit.hidden_dim
self.mask_ratio = 0.75
self.patch_size = vit.patch_size
self.sequence_length = vit.seq_length
        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
        # the decoder is a simple linear layer
        self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)
        # L1 loss, as suggested in the paper
self.criterion = nn.L1Loss()
def forward_encoder(self, images, batch_size, idx_mask):
        # pass all the tokens to the encoder, both masked and non-masked ones
return self.backbone.encode(images=images, idx_mask=idx_mask)
def forward_decoder(self, x_encoded):
return self.decoder(x_encoded)
def training_step(self, batch, batch_idx):
views = batch[0]
images = views[0] # views contains only a single view
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
# Encoding...
x_encoded = self.forward_encoder(images, batch_size, idx_mask)
x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)
# Decoding...
x_out = self.forward_decoder(x_encoded_masked)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_out, target)
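        # optionally log the loss for monitoring, e.g. self.log("train_loss", loss)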
return loss
def configure_optimizers(self):
optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)
return optim
model = SimMIM()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = MAETransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=8,
shuffle=True,
drop_last=True,
num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/simmim.py
The model differs from the non-distributed implementation in the following ways:
- Distributed Data Parallel (DDP) is enabled
- Distributed sampling is used in the dataloader, so each distributed process sees only its own subset of the data (see the sketch after this list)
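Conceptually, distributed sampling replaces the shuffling sampler with a torch DistributedSampler so that every process receives a disjoint shard of the dataset each epoch. Lightning sets this up automatically; the standalone sketch below (with a toy dataset and hard-coded num_replicas and rank, which Lightning would normally infer from the DDP environment) only illustrates the idea:
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# toy dataset standing in for the image dataset used in the example
toy_dataset = TensorDataset(torch.randn(64, 3, 224, 224))
# with two processes, rank 0 and rank 1 would each receive half of the indices
sampler = DistributedSampler(toy_dataset, num_replicas=2, rank=0, shuffle=True)
sharded_loader = DataLoader(toy_dataset, batch_size=8, sampler=sampler)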
# This example requires the following dependencies to be installed:
# pip install lightly
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.transforms.mae_transform import MAETransform # Same transform as MAE
class SimMIM(pl.LightningModule):
def __init__(self):
super().__init__()
vit = torchvision.models.vit_b_32(pretrained=False)
self.mask_ratio = 0.75
self.patch_size = vit.patch_size
self.sequence_length = vit.seq_length
decoder_dim = vit.hidden_dim
self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
# the decoder is a simple linear layer
self.decoder = nn.Linear(decoder_dim, vit.patch_size**2 * 3)
        # L1 loss, as suggested in the paper
self.criterion = nn.L1Loss()
def forward_encoder(self, images, batch_size, idx_mask):
        # pass all the tokens to the encoder, both masked and non-masked ones
return self.backbone.encode(images=images, idx_mask=idx_mask)
def forward_decoder(self, x_encoded):
return self.decoder(x_encoded)
def training_step(self, batch, batch_idx):
views = batch[0]
images = views[0] # views contains only a single view
batch_size = images.shape[0]
idx_keep, idx_mask = utils.random_token_mask(
size=(batch_size, self.sequence_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
# Encoding...
x_encoded = self.forward_encoder(images, batch_size, idx_mask)
x_encoded_masked = utils.get_at_index(x_encoded, idx_mask)
# Decoding...
x_out = self.forward_decoder(x_encoded_masked)
# get image patches for masked tokens
patches = utils.patchify(images, self.patch_size)
# must adjust idx_mask for missing class token
target = utils.get_at_index(patches, idx_mask - 1)
loss = self.criterion(x_out, target)
return loss
def configure_optimizers(self):
optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)
return optim
model = SimMIM()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = MAETransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=8,
shuffle=True,
drop_last=True,
num_workers=8,
)
# Train with DDP on multiple GPUs. Distributed sampling is also enabled with
# use_distributed_sampler=True.
trainer = pl.Trainer(
max_epochs=10,
devices="auto",
accelerator="gpu",
strategy="ddp",
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
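After training, the pretrained backbone weights can be stored for later fine-tuning, for example (the file name is arbitrary):
# with DDP, save from the main process only; only the backbone is needed downstream
if trainer.is_global_zero:
    torch.save(model.backbone.state_dict(), "simmim_backbone.pth")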