MSN
Example implementation of the Masked Siamese Networks (MSN) architecture. MSN is a transformer model based on the Vision Transformer (ViT) architecture. It learns image representations by comparing cluster assignments of masked and unmasked image views. The network is split into a target and anchor network. The target network creates representations from unmasked image views while the anchor network receives a masked image view. MSN increases training efficiency as the backward pass is only calculated for the anchor network. The target network is updated via momentum from the anchor network.
See PMSN for a version of MSN for datasets with non-uniform class distributions.
This example can be run from the command line with:
python lightly/examples/pytorch/msn.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import torch
import torchvision
from torch import nn
from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms.msn_transform import MSNTransform
class MSN(nn.Module):
    """Masked Siamese Network with a frozen target branch and a trained anchor branch.

    ``backbone`` and ``projection_head`` form the target branch; gradients are
    disabled for both. ``anchor_backbone`` and ``anchor_projection_head`` start
    as deep copies of the target modules and are the modules that get trained.

    Args:
        vit: Vision transformer to wrap in ``MaskedVisionTransformerTorchvision``.
            It must expose ``hidden_dim`` and ``patch_size`` (torchvision's
            ``VisionTransformer`` does).
    """

    def __init__(self, vit):
        super().__init__()
        # Fraction of tokens requested from ``utils.random_token_mask``.
        self.mask_ratio = 0.15
        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
        # Size the head from the ViT itself instead of hard-coding 384, so
        # backbones with another embedding width (e.g. vit_b_32) also work.
        self.projection_head = MSNProjectionHead(input_dim=vit.hidden_dim)

        # The anchor branch is copied *before* gradients are switched off on
        # the target branch, so the copies stay trainable.
        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)

        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)

        # Weight of a bias-free Linear(256, 1024): a (1024, 256) parameter.
        # NOTE(review): 256 presumably matches the projection head's output
        # dimension -- confirm against MSNProjectionHead.
        self.prototypes = nn.Linear(256, 1024, bias=False).weight

    def forward(self, images):
        """Run unmasked ``images`` through the target backbone and head."""
        out = self.backbone(images=images)
        return self.projection_head(out)

    def forward_masked(self, images):
        """Run ``images`` through the anchor branch on a random token subset.

        The token count is computed from the image width only, so square
        inputs are assumed.
        """
        batch_size, _, _, width = images.shape
        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
        # Only the indices of the tokens to keep are needed; the second
        # return value is unused.
        idx_keep, _ = utils.random_token_mask(
            size=(batch_size, seq_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        out = self.anchor_backbone(images=images, idx_keep=idx_keep)
        return self.anchor_projection_head(out)
# ViT-S/16 encoder: 16x16 patches on 224x224 inputs, 12 layers, 6 heads,
# 384-dim tokens and an MLP four times that width.
vit = torchvision.models.VisionTransformer(
    image_size=224,
    patch_size=16,
    hidden_dim=384,
    mlp_dim=1536,
    num_heads=6,
    num_layers=12,
)
model = MSN(vit)
# or use a torchvision ViT backbone:
# vit = torchvision.models.vit_b_32(pretrained=False)
# model = MSN(vit)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
transform = MSNTransform()

# The object detection annotations are not needed, so every target is
# replaced by the constant 0.
dataset = torchvision.datasets.VOCDetection(
    root="datasets/pascal_voc",
    transform=transform,
    target_transform=lambda _annotation: 0,
    download=True,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=8
)

criterion = MSNLoss()

# Only the anchor branch and the prototypes are handed to the optimizer; the
# target branch is not part of this list.
params = list(model.anchor_backbone.parameters())
params += list(model.anchor_projection_head.parameters())
params.append(model.prototypes)
optimizer = torch.optim.AdamW(params, lr=1.5e-4)
print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch in dataloader:
        # batch[0] holds the augmented views; the target slot is ignored.
        views = batch[0]

        # Momentum update of the target branch from the anchor branch.
        utils.update_momentum(model.anchor_backbone, model.backbone, 0.996)
        utils.update_momentum(
            model.anchor_projection_head, model.projection_head, 0.996
        )

        views = [view.to(device, non_blocking=True) for view in views]
        targets = views[0]
        anchors = views[1]
        # All remaining views are stacked along the batch dimension.
        anchors_focal = torch.cat(views[2:], dim=0)

        # Target branch: unmasked forward pass.
        targets_out = model.backbone(targets)
        targets_out = model.projection_head(targets_out)

        # Anchor branch: masked forward pass for both kinds of anchor views.
        anchors_out = model.forward_masked(anchors)
        anchors_focal_out = model.forward_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)

        # Pass the parameter itself rather than `.data`: `.data` is cut off
        # from autograd, which would leave the prototypes without a gradient
        # even though they are in the optimizer's parameter list.
        loss = criterion(anchors_out, targets_out, model.prototypes)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/msn.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms.msn_transform import MSNTransform
class MSN(pl.LightningModule):
    """Masked Siamese Network as a LightningModule.

    ``backbone`` and ``projection_head`` form the target branch; gradients are
    disabled for both. ``anchor_backbone`` and ``anchor_projection_head`` start
    as deep copies of the target modules and are the modules that get trained,
    together with ``prototypes``.
    """

    def __init__(self):
        super().__init__()
        # Fraction of tokens requested from ``utils.random_token_mask``.
        self.mask_ratio = 0.15

        # ViT small configuration (ViT-S/16)
        vit = torchvision.models.VisionTransformer(
            image_size=224,
            patch_size=16,
            num_layers=12,
            num_heads=6,
            hidden_dim=384,
            mlp_dim=384 * 4,
        )
        # or use a torchvision ViT backbone:
        # vit = torchvision.models.vit_b_32(pretrained=False)
        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
        # Size the head from the ViT itself instead of hard-coding 384, so
        # swapping in another ViT above needs no further change here.
        self.projection_head = MSNProjectionHead(vit.hidden_dim)

        # The anchor branch is copied *before* gradients are switched off on
        # the target branch, so the copies stay trainable.
        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)

        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)

        # Weight of a bias-free Linear(256, 1024): a (1024, 256) parameter.
        # NOTE(review): 256 presumably matches the projection head's output
        # dimension -- confirm against MSNProjectionHead.
        self.prototypes = nn.Linear(256, 1024, bias=False).weight
        self.criterion = MSNLoss()

    def training_step(self, batch, batch_idx):
        """Compute the MSN loss for one batch.

        ``batch[0]`` is read as a sequence of views: index 0 is the target
        view, index 1 the first anchor view, and all remaining views are
        stacked along the batch dimension as additional anchors.
        """
        # Momentum update of the target branch from the anchor branch.
        utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
        utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)

        views = batch[0]
        views = [view.to(self.device, non_blocking=True) for view in views]
        targets = views[0]
        anchors = views[1]
        anchors_focal = torch.cat(views[2:], dim=0)

        targets_out = self.backbone(images=targets)
        targets_out = self.projection_head(targets_out)
        anchors_out = self.encode_masked(anchors)
        anchors_focal_out = self.encode_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)

        # Pass the parameter itself rather than `.data`: `.data` is cut off
        # from autograd, which would leave the prototypes without a gradient
        # even though configure_optimizers hands them to the optimizer.
        loss = self.criterion(anchors_out, targets_out, self.prototypes)
        return loss

    def encode_masked(self, anchors):
        """Run ``anchors`` through the anchor branch on a random token subset.

        The token count is computed from the image width only, so square
        inputs are assumed.
        """
        batch_size, _, _, width = anchors.shape
        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
        # Only the indices of the tokens to keep are needed; the second
        # return value is unused.
        idx_keep, _ = utils.random_token_mask(
            size=(batch_size, seq_length),
            mask_ratio=self.mask_ratio,
            device=self.device,
        )
        out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)
        return self.anchor_projection_head(out)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the anchor branch and the prototypes."""
        params = [
            *list(self.anchor_backbone.parameters()),
            *list(self.anchor_projection_head.parameters()),
            self.prototypes,
        ]
        optim = torch.optim.AdamW(params, lr=1.5e-4)
        return optim
model = MSN()

transform = MSNTransform()

# The object detection annotations are not needed, so every target is
# replaced by the constant 0.
dataset = torchvision.datasets.VOCDetection(
    root="datasets/pascal_voc",
    transform=transform,
    target_transform=lambda _annotation: 0,
    download=True,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=8
)

# Use a single GPU when one is visible, otherwise train on the CPU.
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

trainer = pl.Trainer(accelerator=accelerator, devices=1, max_epochs=10)
trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/msn.py
The model differs in the following ways from the non-distributed implementation:
Distributed Data Parallel is enabled
Distributed Sampling is used in the dataloader
Distributed Sinkhorn is used in the loss calculation
Distributed Sampling makes sure that each distributed process sees only a subset of the data.
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with one or more GPUs.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import MSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms.msn_transform import MSNTransform
class MSN(pl.LightningModule):
    """Masked Siamese Network as a LightningModule for distributed training.

    ``backbone`` and ``projection_head`` form the target branch; gradients are
    disabled for both. ``anchor_backbone`` and ``anchor_projection_head`` start
    as deep copies of the target modules and are the modules that get trained,
    together with ``prototypes``. The loss is created with
    ``gather_distributed=True``.
    """

    def __init__(self):
        super().__init__()
        # Fraction of tokens requested from ``utils.random_token_mask``.
        self.mask_ratio = 0.15

        # ViT small configuration (ViT-S/16)
        vit = torchvision.models.VisionTransformer(
            image_size=224,
            patch_size=16,
            num_layers=12,
            num_heads=6,
            hidden_dim=384,
            mlp_dim=384 * 4,
        )
        # or use a torchvision ViT backbone:
        # vit = torchvision.models.vit_b_32(pretrained=False)
        self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
        # Size the head from the ViT itself instead of hard-coding 384, so
        # swapping in another ViT above needs no further change here.
        self.projection_head = MSNProjectionHead(vit.hidden_dim)

        # The anchor branch is copied *before* gradients are switched off on
        # the target branch, so the copies stay trainable.
        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)

        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)

        # Weight of a bias-free Linear(256, 1024): a (1024, 256) parameter.
        # NOTE(review): 256 presumably matches the projection head's output
        # dimension -- confirm against MSNProjectionHead.
        self.prototypes = nn.Linear(256, 1024, bias=False).weight

        # set gather_distributed to True for distributed training
        self.criterion = MSNLoss(gather_distributed=True)

    def training_step(self, batch, batch_idx):
        """Compute the MSN loss for one batch.

        ``batch[0]`` is read as a sequence of views: index 0 is the target
        view, index 1 the first anchor view, and all remaining views are
        stacked along the batch dimension as additional anchors.
        """
        # Momentum update of the target branch from the anchor branch.
        utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
        utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)

        views = batch[0]
        views = [view.to(self.device, non_blocking=True) for view in views]
        targets = views[0]
        anchors = views[1]
        anchors_focal = torch.cat(views[2:], dim=0)

        targets_out = self.backbone(images=targets)
        targets_out = self.projection_head(targets_out)
        anchors_out = self.encode_masked(anchors)
        anchors_focal_out = self.encode_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)

        # Pass the parameter itself rather than `.data`: `.data` is cut off
        # from autograd, which would leave the prototypes without a gradient
        # even though configure_optimizers hands them to the optimizer.
        loss = self.criterion(anchors_out, targets_out, self.prototypes)
        return loss

    def encode_masked(self, anchors):
        """Run ``anchors`` through the anchor branch on a random token subset.

        The token count is computed from the image width only, so square
        inputs are assumed.
        """
        batch_size, _, _, width = anchors.shape
        seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
        # Only the indices of the tokens to keep are needed; the second
        # return value is unused.
        idx_keep, _ = utils.random_token_mask(
            size=(batch_size, seq_length),
            mask_ratio=self.mask_ratio,
            device=self.device,
        )
        out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)
        return self.anchor_projection_head(out)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the anchor branch and the prototypes."""
        params = [
            *list(self.anchor_backbone.parameters()),
            *list(self.anchor_projection_head.parameters()),
            self.prototypes,
        ]
        optim = torch.optim.AdamW(params, lr=1.5e-4)
        return optim
model = MSN()

transform = MSNTransform()

# The object detection annotations are not needed, so every target is
# replaced by the constant 0.
dataset = torchvision.datasets.VOCDetection(
    root="datasets/pascal_voc",
    transform=transform,
    target_transform=lambda _annotation: 0,
    download=True,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=8
)

# Train with DDP on all available GPUs. Distributed sampling is requested with
# use_distributed_sampler=True; PyTorch Lightning <2.0 calls the same option
# replace_sampler_ddp.
trainer = pl.Trainer(
    accelerator="gpu",
    devices="auto",
    strategy="ddp",
    max_epochs=10,
    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)