PMSN
Prior Matching for Siamese Networks (PMSN) builds on top of MSN by adding support for custom clustering priors. This is especially helpful for datasets with non-uniform class distributions. By default, PMSN uses a power law prior, which is well suited to datasets with long-tailed class distributions.
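For intuition, the power law prior is simply a distribution over the prototypes whose probabilities decay polynomially with rank. Below is a minimal sketch of such a prior (not the library implementation); the exponent plays the role of the power_law_exponent argument shown further down:
import torch

def power_law_prior(num_prototypes: int, exponent: float = 0.25) -> torch.Tensor:
    # p_k is proportional to k^(-exponent), normalized to sum to one
    ranks = torch.arange(1, num_prototypes + 1, dtype=torch.float32)
    prior = ranks ** (-exponent)
    return prior / prior.sum()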
For PMSN, you can use the exact same code as for MSN but change lightly.loss.msn_loss.MSNLoss to lightly.loss.pmsn_loss.PMSNLoss:
# instead of this
from lightly.loss import MSNLoss
criterion = MSNLoss()
# use this
from lightly.loss import PMSNLoss
criterion = PMSNLoss(power_law_exponent=0.25)
# or define your custom target distribution
from torch import Tensor

from lightly.loss import PMSNCustomLoss

def my_uniform_target_distribution(mean_anchor_probabilities: Tensor) -> Tensor:
    # return a uniform distribution over the prototype dimension
    dim = mean_anchor_probabilities.shape[0]
    return mean_anchor_probabilities.new_ones(dim) / dim
criterion = PMSNCustomLoss(target_distribution=my_uniform_target_distribution)
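All three loss variants are called the same way as in the training loops below: with the anchor embeddings, the target embeddings, and the prototype weights. A minimal sketch with dummy tensors; the shapes are only illustrative and match the 256-dimensional projection head output and the 1024 prototypes used in the examples:
import torch

anchors = torch.rand(8, 256)        # embeddings of anchor and focal views
targets = torch.rand(4, 256)        # embeddings of target views
prototypes = torch.rand(1024, 256)  # prototype weight matrix
loss = criterion(anchors, targets, prototypes)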
This example can be run from the command line with:
python lightly/examples/pytorch/pmsn.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import torch
import torchvision
from torch import nn
from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms import MSNTransform
class PMSN(nn.Module):
def __init__(self, vit):
super().__init__()
self.mask_ratio = 0.15
self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
self.projection_head = MSNProjectionHead(384)
self.anchor_backbone = copy.deepcopy(self.backbone)
self.anchor_projection_head = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone)
utils.deactivate_requires_grad(self.projection_head)
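        # 1024 learnable prototypes with dimension 256, matching the output
        # dimension of the projection head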
self.prototypes = nn.Linear(256, 1024, bias=False).weight
def forward(self, images):
out = self.backbone(images=images)
return self.projection_head(out)
def forward_masked(self, images):
batch_size, _, _, width = images.shape
seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
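        # Randomly drop a fraction (mask_ratio) of the patch tokens; only the
        # kept tokens are encoded by the anchor backbone.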
idx_keep, _ = utils.random_token_mask(
size=(batch_size, seq_length),
mask_ratio=self.mask_ratio,
device=images.device,
)
out = self.anchor_backbone(images=images, idx_keep=idx_keep)
return self.anchor_projection_head(out)
# ViT small configuration (ViT-S/16)
vit = torchvision.models.VisionTransformer(
image_size=224,
patch_size=16,
num_layers=12,
num_heads=6,
hidden_dim=384,
mlp_dim=384 * 4,
)
model = PMSN(vit)
# or use a torchvision ViT backbone:
# vit = torchvision.models.vit_b_32(pretrained=False)
# model = PMSN(vit)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = MSNTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=8,
)
criterion = PMSNLoss()
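# Only the anchor (student) networks and the prototypes receive gradients;
# the target backbone and projection head are updated via momentum instead.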
params = [
*list(model.anchor_backbone.parameters()),
*list(model.anchor_projection_head.parameters()),
model.prototypes,
]
optimizer = torch.optim.AdamW(params, lr=1.5e-4)
print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch in dataloader:
views = batch[0]
utils.update_momentum(model.anchor_backbone, model.backbone, 0.996)
utils.update_momentum(
model.anchor_projection_head, model.projection_head, 0.996
)
views = [view.to(device, non_blocking=True) for view in views]
targets = views[0]
anchors = views[1]
anchors_focal = torch.concat(views[2:], dim=0)
targets_out = model.backbone(images=targets)
targets_out = model.projection_head(targets_out)
anchors_out = model.forward_masked(anchors)
anchors_focal_out = model.forward_masked(anchors_focal)
anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)
loss = criterion(anchors_out, targets_out, model.prototypes.data)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/pmsn.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms import MSNTransform
class PMSN(pl.LightningModule):
def __init__(self):
super().__init__()
# ViT small configuration (ViT-S/16)
self.mask_ratio = 0.15
vit = torchvision.models.VisionTransformer(
image_size=224,
patch_size=16,
num_layers=12,
num_heads=6,
hidden_dim=384,
mlp_dim=384 * 4,
)
self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
# or use a torchvision ViT backbone:
# vit = torchvision.models.vit_b_32(pretrained=False)
        # self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
self.projection_head = MSNProjectionHead(384)
self.anchor_backbone = copy.deepcopy(self.backbone)
self.anchor_projection_head = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone)
utils.deactivate_requires_grad(self.projection_head)
self.prototypes = nn.Linear(256, 1024, bias=False).weight
self.criterion = PMSNLoss()
def training_step(self, batch, batch_idx):
utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)
views = batch[0]
views = [view.to(self.device, non_blocking=True) for view in views]
targets = views[0]
anchors = views[1]
anchors_focal = torch.concat(views[2:], dim=0)
targets_out = self.backbone(images=targets)
targets_out = self.projection_head(targets_out)
anchors_out = self.encode_masked(anchors)
anchors_focal_out = self.encode_masked(anchors_focal)
anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)
loss = self.criterion(anchors_out, targets_out, self.prototypes.data)
return loss
def encode_masked(self, anchors):
batch_size, _, _, width = anchors.shape
seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
idx_keep, _ = utils.random_token_mask(
size=(batch_size, seq_length),
mask_ratio=self.mask_ratio,
device=self.device,
)
out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)
return self.anchor_projection_head(out)
def configure_optimizers(self):
params = [
*list(self.anchor_backbone.parameters()),
*list(self.anchor_projection_head.parameters()),
self.prototypes,
]
optim = torch.optim.AdamW(params, lr=1.5e-4)
return optim
model = PMSN()
transform = MSNTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
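Once training has finished, the frozen target backbone can be used to embed images. A minimal sketch with a dummy batch; for the ViT-S/16 configuration above the embeddings have 384 dimensions:
model.eval()
images = torch.rand(4, 3, 224, 224, device=model.device)  # dummy batch for illustration
with torch.no_grad():
    embeddings = model.backbone(images=images)  # shape: (4, 384)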
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/pmsn.py
The model differs in the following ways from the non-distributed implementation:
- Distributed Data Parallel is enabled
- Distributed Sampling is used in the dataloader
- Distributed Sinkhorn is used in the loss calculation
Distributed Sampling makes sure that each distributed process sees only a subset of the data.
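For reference, plain PyTorch achieves the same effect by passing a DistributedSampler to the dataloader. A minimal sketch, assuming the process group has already been initialized (PyTorch Lightning handles both steps for you when use_distributed_sampler=True):
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Each process iterates only over its own shard of the dataset.
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=8)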
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import PMSNLoss
from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTorchvision
from lightly.models.modules.heads import MSNProjectionHead
from lightly.transforms import MSNTransform
class PMSN(pl.LightningModule):
def __init__(self):
super().__init__()
# ViT small configuration (ViT-S/16)
self.mask_ratio = 0.15
vit = torchvision.models.VisionTransformer(
image_size=224,
patch_size=16,
num_layers=12,
num_heads=6,
hidden_dim=384,
mlp_dim=384 * 4,
)
self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
# or use a torchvision ViT backbone:
# vit = torchvision.models.vit_b_32(pretrained=False)
        # self.backbone = MaskedVisionTransformerTorchvision(vit=vit)
self.projection_head = MSNProjectionHead(384)
self.anchor_backbone = copy.deepcopy(self.backbone)
self.anchor_projection_head = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone)
utils.deactivate_requires_grad(self.projection_head)
self.prototypes = nn.Linear(256, 1024, bias=False).weight
# set gather_distributed to True for distributed training
self.criterion = PMSNLoss(gather_distributed=True)
def training_step(self, batch, batch_idx):
utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)
views = batch[0]
views = [view.to(self.device, non_blocking=True) for view in views]
targets = views[0]
anchors = views[1]
anchors_focal = torch.concat(views[2:], dim=0)
targets_out = self.backbone(images=targets)
targets_out = self.projection_head(targets_out)
anchors_out = self.encode_masked(anchors)
anchors_focal_out = self.encode_masked(anchors_focal)
anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)
loss = self.criterion(anchors_out, targets_out, self.prototypes.data)
return loss
def encode_masked(self, anchors):
batch_size, _, _, width = anchors.shape
seq_length = (width // self.anchor_backbone.vit.patch_size) ** 2
idx_keep, _ = utils.random_token_mask(
size=(batch_size, seq_length),
mask_ratio=self.mask_ratio,
device=self.device,
)
out = self.anchor_backbone(images=anchors, idx_keep=idx_keep)
return self.anchor_projection_head(out)
def configure_optimizers(self):
params = [
*list(self.anchor_backbone.parameters()),
*list(self.anchor_projection_head.parameters()),
self.prototypes,
]
optim = torch.optim.AdamW(params, lr=1.5e-4)
return optim
model = PMSN()
transform = MSNTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=8,
)
# Train with DDP on multiple GPUs. Distributed sampling is also enabled by
# setting use_distributed_sampler=True (or replace_sampler_ddp=True for
# PyTorch Lightning <2.0).
trainer = pl.Trainer(
max_epochs=10,
devices="auto",
accelerator="gpu",
strategy="ddp",
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)