SimCLR
SimCLR is a self-supervised framework for learning visual representations with a contrastive objective. It creates two augmented views of the same image, using random cropping, color jitter, and Gaussian blur, then maximizes agreement between these views while pushing them apart from views of other images. Key findings include the importance of composing strong data augmentations, a nonlinear projection head that improves representation quality, and the benefit of large batch sizes. Together, these elements allow SimCLR to approach or match supervised performance on ImageNet and to achieve strong transfer and semi-supervised learning results.
Key Components
Data Augmentations: SimCLR uses random cropping, resizing, color jittering, and Gaussian blur to create diverse views of the same image.
Backbone: Convolutional neural networks, such as ResNet, are employed to encode augmented images into feature representations.
Projection Head: A multilayer perceptron (MLP) maps features into a space where contrastive loss is applied, enhancing representation quality.
Contrastive Loss: The normalized temperature-scaled cross-entropy loss (NT-Xent) pulls the two views of the same image together while pushing views of different images apart (a compact sketch follows below).
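To make the loss concrete, the snippet below is a compact NT-Xent sketch, not the NTXentLoss implementation shipped with lightly; the temperature value and tensor shapes are illustrative. Each image yields two projections z0[i] and z1[i] that form a positive pair, and every other sample in the batch serves as a negative.

import torch
import torch.nn.functional as F

def nt_xent(z0: torch.Tensor, z1: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    # z0, z1: projections of the two views, shape (batch_size, dim)
    z = F.normalize(torch.cat([z0, z1], dim=0), dim=1)  # unit-length rows, shape (2N, dim)
    sim = z @ z.T / temperature  # pairwise cosine similarities scaled by temperature
    sim.fill_diagonal_(float("-inf"))  # a sample is never contrasted with itself
    n = z0.shape[0]
    # the positive for row i is row i + n (and vice versa)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)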
Good to Know
Backbone Networks: SimCLR is specifically optimized for convolutional neural networks, with a focus on ResNet architectures. We do not recommend using it with transformer-based models.
Learning Paradigm: SimCLR is based on contrastive learning, which makes it sensitive to the choice of augmentations; it also benefits from larger batch sizes, since each batch supplies the negatives for the contrastive loss (see the sketch below).
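As a hedged illustration of tuning the augmentation recipe, the snippet below adjusts a few SimCLRTransform arguments. Only input_size and gaussian_blur appear in the examples further down; cj_prob and cj_strength are assumed parameter names that may differ between lightly releases, so check the transform's signature in your installed version.

from lightly.transforms.simclr_transform import SimCLRTransform

# Illustrative only: cj_prob / cj_strength are assumed parameter names.
transform = SimCLRTransform(
    input_size=224,     # resolution of the two random crops
    cj_prob=0.8,        # probability of applying color jitter (assumed name)
    cj_strength=1.0,    # color jitter strength (assumed name)
    gaussian_blur=0.5,  # probability of applying Gaussian blur
)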
- Reference: A Simple Framework for Contrastive Learning of Visual Representations, Chen et al., 2020 (arXiv:2002.05709)
This example can be run from the command line with:
python lightly/examples/pytorch/simclr.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class SimCLR(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(512, 512, 128)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimCLR(backbone)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = SimCLRTransform(input_size=32, gaussian_blur=0.0)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
criterion = NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)
print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch in dataloader:
        # the SimCLR transform returns two augmented views of every image
        x0, x1 = batch[0]
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
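After pretraining, the projection head is usually discarded and only the backbone is used to embed images. Below is a minimal sketch, assuming an eval_dataloader (a hypothetical placeholder) that serves images with deterministic, non-SimCLR transforms:

model.eval()
embeddings = []
with torch.no_grad():
    for images, _ in eval_dataloader:  # eval_dataloader is a placeholder, not defined above
        features = model.backbone(images.to(device)).flatten(start_dim=1)
        embeddings.append(features.cpu())
embeddings = torch.cat(embeddings)  # shape: (num_images, 512) for a ResNet-18 backbone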
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/simclr.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class SimCLR(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)
        self.criterion = NTXentLoss()

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_index):
        (x0, x1) = batch[0]
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
model = SimCLR()
transform = SimCLRTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/simclr.py
The model differs in the following ways from the non-distributed implementation:
Distributed Data Parallel is enabled
Synchronized Batch Norm is used in place of standard Batch Norm
Features are gathered from all GPUs before the loss is calculated
Note that Synchronized Batch Norm and feature gathering are optional; the model can also be trained without them. Without them, the batch norm statistics and the loss on each GPU are computed only from the features on that GPU.
# This example requires the following dependencies to be installed:
# pip install lightly
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class SimCLR(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)
        # enable gather_distributed to gather features from all gpus
        # before calculating the loss
        self.criterion = NTXentLoss(gather_distributed=True)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_index):
        (x0, x1) = batch[0]
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
model = SimCLR()
transform = SimCLRTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation. Distributed sampling is also enabled via use_distributed_sampler=True
# (replace_sampler_ddp=True for PyTorch Lightning <2.0).
trainer = pl.Trainer(
    max_epochs=10,
    devices="auto",
    accelerator="gpu",
    strategy="ddp",
    sync_batchnorm=True,
    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)