VICRegL
VICRegL (VICRegL: Self-Supervised Learning of Local Visual Features, 2022) is a method derived from VICReg. Like the standard VICReg, it avoids the collapse problem with a simple regularization term on the variance of the embeddings along each dimension individually. Moreover, it learns good global and local features simultaneously, yielding excellent performance on detection and segmentation tasks while maintaining good performance on classification tasks.
This example can be run from the command line with:
python lightly/examples/pytorch/vicregl.py
# This example requires the following dependencies to be installed:
# pip install lightly
import torch
import torchvision
from torch import nn
from lightly.loss import VICRegLLoss
## The global projection head is the same as the Barlow Twins one
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VicRegLLocalProjectionHead
from lightly.transforms.vicregl_transform import VICRegLTransform
class VICRegL(nn.Module):
    """Backbone with a global and a local projection head for VICRegL.

    Args:
        backbone: Module whose output is a 4-D feature map. ``forward``
            average-pools it over the last two axes for the global branch and
            moves axis 1 to the end for the local branch.
        input_dim: Size of axis 1 (the channel axis) of the backbone output.
            It is used as the input size of both projection heads. Defaults
            to 512, the value that was previously hard-coded.
    """

    def __init__(self, backbone, input_dim=512):
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(input_dim, 2048, 2048)
        self.local_projection_head = VicRegLLocalProjectionHead(input_dim, 128, 128)
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, x):
        """Compute the global and local embeddings of a batch.

        Args:
            x: Input batch that is passed to the backbone.

        Returns:
            Tuple ``(z, z_local)``: the output of the global projection head
            applied to the pooled feature map, and the output of the local
            projection head applied to the channels-last feature map.
        """
        feature_map = self.backbone(x)
        # Global branch: pool every channel down to 1x1, then drop the two
        # singleton spatial axes.
        pooled = self.average_pool(feature_map).flatten(start_dim=1)
        z = self.projection_head(pooled)
        # Local branch: channels-last layout, (B, D, H, W) to (B, H, W, D).
        y_local = feature_map.permute(0, 2, 3, 1)
        z_local = self.local_projection_head(y_local)
        return z, z_local
def _discard_target(target):
    """Return a constant 0 in place of the dataset's annotation.

    This is a module-level function rather than a lambda so that the dataset
    can be pickled for DataLoader worker processes on platforms that start
    workers with ``spawn``.
    """
    return 0


# The guard keeps DataLoader worker processes, which may re-import this
# script, from starting the training a second time.
if __name__ == "__main__":
    resnet = torchvision.models.resnet18()
    # Keep everything except the last two child modules; VICRegL.forward does
    # the pooling and the permutation of the backbone output itself.
    backbone = nn.Sequential(*list(resnet.children())[:-2])
    model = VICRegL(backbone)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    transform = VICRegLTransform(n_local_views=0)
    # we ignore object detection annotations by setting target_transform to return 0
    dataset = torchvision.datasets.VOCDetection(
        "datasets/pascal_voc",
        download=True,
        transform=transform,
        target_transform=_discard_target,
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    criterion = VICRegLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    print("Starting Training")
    for epoch in range(10):
        total_loss = 0
        for views_and_grids, _ in dataloader:
            views_and_grids = [x.to(device) for x in views_and_grids]
            # The first half of the list holds the views, the second half the
            # grids that belong to them.
            num_views = len(views_and_grids) // 2
            views = views_and_grids[:num_views]
            grids = views_and_grids[num_views:]
            features = [model(view) for view in views]
            # The first two entries are treated as global views. With
            # n_local_views=0 the local slices are presumably empty -- confirm
            # against VICRegLTransform.
            loss = criterion(
                global_view_features=features[:2],
                global_view_grids=grids[:2],
                local_view_features=features[2:],
                local_view_grids=grids[2:],
            )
            # detach() so the running sum does not keep the graph alive.
            total_loss += loss.detach()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        avg_loss = total_loss / len(dataloader)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/vicregl.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import VICRegLLoss
## The global projection head is the same as the Barlow Twins one
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VicRegLLocalProjectionHead
from lightly.transforms.vicregl_transform import VICRegLTransform
class VICRegL(pl.LightningModule):
    """LightningModule that trains a ResNet-18 backbone with the VICRegL loss.

    The module owns the backbone, a global and a local projection head, the
    pooling layer for the global branch, and the loss criterion.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Keep everything except the last two child modules; forward() does
        # the pooling and the permutation of the backbone output itself.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
        self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.criterion = VICRegLLoss()

    def forward(self, x):
        """Compute the global and local embeddings of a batch.

        Args:
            x: Input batch that is passed to the backbone.

        Returns:
            Tuple ``(z, z_local)``: the output of the global projection head
            applied to the pooled feature map, and the output of the local
            projection head applied to the channels-last feature map.
        """
        x = self.backbone(x)
        y = self.average_pool(x).flatten(start_dim=1)
        z = self.projection_head(y)
        # Channels-last layout for the local head: (B, D, H, W) to (B, H, W, D).
        y_local = x.permute(0, 2, 3, 1)
        z_local = self.local_projection_head(y_local)
        return z, z_local

    def training_step(self, batch, batch_index):
        """Compute the VICRegL loss for one batch.

        Args:
            batch: Sequence whose first element is a list holding the views
                in its first half and the matching grids in its second half.
            batch_index: Index of the batch (unused).

        Returns:
            The loss returned by the criterion.
        """
        views_and_grids = batch[0]
        num_views = len(views_and_grids) // 2
        views = views_and_grids[:num_views]
        grids = views_and_grids[num_views:]
        # Call the module itself instead of forward() so that registered
        # hooks are run.
        features = [self(view) for view in views]
        loss = self.criterion(
            global_view_features=features[:2],
            global_view_grids=grids[:2],
            local_view_features=features[2:],
            local_view_grids=grids[2:],
        )
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over this module's own parameters."""
        # Use self rather than a global `model` variable, so the module does
        # not depend on how the surrounding script names its instance.
        optim = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        return optim
def _discard_target(target):
    """Return a constant 0 in place of the dataset's annotation.

    This is a module-level function rather than a lambda so that the dataset
    can be pickled for DataLoader worker processes on platforms that start
    workers with ``spawn``.
    """
    return 0


# The guard keeps DataLoader worker processes, which may re-import this
# script, from starting the training a second time. The names assigned below
# remain module-level globals.
if __name__ == "__main__":
    model = VICRegL()

    transform = VICRegLTransform(n_local_views=0)
    # we ignore object detection annotations by setting target_transform to return 0
    dataset = torchvision.datasets.VOCDetection(
        "datasets/pascal_voc",
        download=True,
        transform=transform,
        target_transform=_discard_target,
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train on one GPU when CUDA is available, otherwise on the CPU.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
    trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/vicregl.py
The model differs in the following ways from the non-distributed implementation:
Distributed Data Parallel is enabled
Distributed Sampling is used in the dataloader
Distributed Sampling makes sure that each distributed process sees only a subset of the data.
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset on a single machine with one or more GPUs.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import VICRegLLoss
## The global projection head is the same as the Barlow Twins one
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.models.modules.heads import VicRegLLocalProjectionHead
from lightly.transforms.vicregl_transform import VICRegLTransform
class VICRegL(pl.LightningModule):
    """LightningModule that trains a ResNet-18 backbone with the VICRegL loss.

    The module owns the backbone, a global and a local projection head, the
    pooling layer for the global branch, and the loss criterion.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Keep everything except the last two child modules; forward() does
        # the pooling and the permutation of the backbone output itself.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
        self.local_projection_head = VicRegLLocalProjectionHead(512, 128, 128)
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.criterion = VICRegLLoss()

    def forward(self, x):
        """Compute the global and local embeddings of a batch.

        Args:
            x: Input batch that is passed to the backbone.

        Returns:
            Tuple ``(z, z_local)``: the output of the global projection head
            applied to the pooled feature map, and the output of the local
            projection head applied to the channels-last feature map.
        """
        x = self.backbone(x)
        y = self.average_pool(x).flatten(start_dim=1)
        z = self.projection_head(y)
        # Channels-last layout for the local head: (B, D, H, W) to (B, H, W, D).
        y_local = x.permute(0, 2, 3, 1)
        z_local = self.local_projection_head(y_local)
        return z, z_local

    def training_step(self, batch, batch_index):
        """Compute the VICRegL loss for one batch.

        Args:
            batch: Sequence whose first element is a list holding the views
                in its first half and the matching grids in its second half.
            batch_index: Index of the batch (unused).

        Returns:
            The loss returned by the criterion.
        """
        views_and_grids = batch[0]
        num_views = len(views_and_grids) // 2
        views = views_and_grids[:num_views]
        grids = views_and_grids[num_views:]
        # Call the module itself instead of forward() so that registered
        # hooks are run.
        features = [self(view) for view in views]
        loss = self.criterion(
            global_view_features=features[:2],
            global_view_grids=grids[:2],
            local_view_features=features[2:],
            local_view_grids=grids[2:],
        )
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over this module's own parameters."""
        # Use self rather than a global `model` variable, so the module does
        # not depend on how the surrounding script names its instance.
        optim = torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        return optim
def _discard_target(target):
    """Return a constant 0 in place of the dataset's annotation.

    This is a module-level function rather than a lambda so that the dataset
    can be pickled for DataLoader worker processes on platforms that start
    workers with ``spawn``.
    """
    return 0


# The guard keeps DataLoader worker processes, which may re-import this
# script, from starting the training a second time. The names assigned below
# remain module-level globals.
if __name__ == "__main__":
    model = VICRegL()

    transform = VICRegLTransform(n_local_views=0)
    # we ignore object detection annotations by setting target_transform to return 0
    dataset = torchvision.datasets.VOCDetection(
        "datasets/pascal_voc",
        download=True,
        transform=transform,
        target_transform=_discard_target,
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train with DDP and use Synchronized Batch Norm for a more accurate batch
    # norm calculation. Distributed sampling is enabled with
    # use_distributed_sampler=True (replace_sampler_ddp=True for PyTorch
    # Lightning <2.0).
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="gpu",
        strategy="ddp",
        sync_batchnorm=True,
        use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
    )
    trainer.fit(model=model, train_dataloaders=dataloader)