NNCLR
NNCLR is a self-supervised framework for visual representation learning that builds upon contrastive methods. It shares similarities with SimCLR, such as using two augmented views of the same image, projection and prediction heads, and a contrastive loss. However, it introduces key modifications:
Nearest Neighbor Replacement: Instead of directly comparing two augmented views of the same sample, NNCLR replaces each sample with its nearest neighbor in a support set (or memory bank). This increases semantic variation in the learned representations.
Symmetric Loss: The contrastive loss is made symmetric to improve training stability.
Architectural Adjustments: NNCLR employs different sizes for projection and prediction head layers compared to SimCLR.
These improvements result in significantly better performance across multiple self-supervised learning benchmarks. Compared to SimCLR and other self-supervised methods, NNCLR achieves:
Higher ImageNet linear evaluation accuracy.
Improved semi-supervised learning results.
Superior performance on transfer learning tasks, outperforming BYOL, SimCLR, and even supervised ImageNet pretraining in 8 out of 12 benchmarked cases.
Key Components
Data Augmentations: NNCLR applies the same transformations as SimCLR, including random cropping, resizing, color jittering, and Gaussian blur, to create diverse views of the same image.
Backbone: A convolutional neural network (typically ResNet) encodes augmented images into feature representations.
Projection Head: A multilayer perceptron (MLP) maps features into a contrastive space, improving representation learning.
Memory Bank: NNCLR maintains a first-in, first-out (FIFO) memory bank, storing past feature representations. Older features are gradually discarded, ensuring a large and diverse set approximating the full dataset.
Nearest Neighbor Sampling: Each feature representation is replaced by its nearest neighbor from the memory bank, introducing additional semantic variation beyond standard augmentations.
Contrastive Loss: NNCLR employs normalized temperature-scaled cross-entropy loss (NT-Xent), encouraging alignment between positive pairs and separation from negative pairs.
Good to Know
Optimized for CNNs: NNCLR is specifically designed for convolutional neural networks (CNNs), particularly ResNet. It is not recommended for transformer-based architectures.
Augmentation Robustness: Compared to SimCLR, NNCLR is less dependent on strong augmentations since nearest neighbor sampling introduces natural semantic variation. However, performance still benefits from well-chosen augmentations and larger batch sizes.
This example can be run from the command line with:
python lightly/examples/pytorch/nnclr.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import (
NNCLRPredictionHead,
NNCLRProjectionHead,
NNMemoryBankModule,
)
from lightly.transforms.simclr_transform import SimCLRTransform
class NNCLR(nn.Module):
    """NNCLR network: a backbone followed by a projection and a prediction head.

    Args:
        backbone: Module that encodes a batch of images. Its output is
            flattened from dimension 1 onwards before it reaches the
            projection head (constructed here with sizes 512, 512, 128).
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = NNCLRProjectionHead(512, 512, 128)
        self.prediction_head = NNCLRPredictionHead(128, 512, 128)

    def forward(self, x):
        """Encode ``x`` and return ``(z, p)``.

        ``z`` is the projection-head output, detached from the autograd graph;
        ``p`` is the prediction-head output computed from the projection and
        still carries gradients.
        """
        features = self.backbone(x).flatten(start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        # Only the prediction branch propagates gradients; the projection that
        # is handed back to the caller is cut off from the graph.
        return projection.detach(), prediction
def main():
    """Train NNCLR on CIFAR-10 for 10 epochs and print the mean loss per epoch.

    The whole script body lives in this function and is only invoked under the
    ``__main__`` guard below. The DataLoader uses worker processes, and on
    platforms that start workers with ``spawn`` (Windows, macOS) each worker
    re-imports this file; without the guard the training code would be
    re-executed in every worker.
    """
    resnet = torchvision.models.resnet18()
    # Use every child module of the ResNet except the last one as the backbone.
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = NNCLR(backbone)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    memory_bank = NNMemoryBankModule(size=(4096, 128))
    memory_bank.to(device)

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    criterion = NTXentLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

    print("Starting Training")
    for epoch in range(10):
        total_loss = 0
        for batch in dataloader:
            # batch[0] holds the two augmented views of each image.
            x0, x1 = batch[0]
            x0 = x0.to(device)
            x1 = x1.to(device)
            z0, p0 = model(x0)
            z1, p1 = model(x1)
            # Swap each projection for its nearest neighbor from the memory
            # bank; the bank itself is only updated on the second call.
            z0 = memory_bank(z0, update=False)
            z1 = memory_bank(z1, update=True)
            # Symmetric loss: neighbors of one view against predictions of
            # the other view, in both directions.
            loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
            total_loss += loss.detach()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        avg_loss = total_loss / len(dataloader)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


if __name__ == "__main__":
    main()
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/nnclr.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import (
NNCLRPredictionHead,
NNCLRProjectionHead,
NNMemoryBankModule,
)
from lightly.transforms.simclr_transform import SimCLRTransform
class NNCLR(pl.LightningModule):
    """NNCLR as a LightningModule.

    Bundles a ResNet-18 backbone, the projection and prediction heads, the
    nearest-neighbor memory bank and the NT-Xent loss.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Use every child module of the ResNet except the last one as the backbone.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = NNCLRProjectionHead(512, 512, 128)
        self.prediction_head = NNCLRPredictionHead(128, 512, 128)
        self.memory_bank = NNMemoryBankModule(size=(4096, 128))
        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction of ``x``."""
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        # The projection is returned without gradient history; gradients flow
        # only through the prediction branch.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute the symmetric NNCLR loss for one batch.

        ``batch[0]`` holds the two augmented views of each image.
        """
        x0, x1 = batch[0]
        # Call the module itself instead of ``forward`` so that any registered
        # module hooks are executed.
        z0, p0 = self(x0)
        z1, p1 = self(x1)
        # Swap each projection for its nearest neighbor from the memory bank;
        # the bank itself is only updated on the second call.
        z0 = self.memory_bank(z0, update=False)
        z1 = self.memory_bank(z1, update=True)
        loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))
        return loss

    def configure_optimizers(self):
        """Return the SGD optimizer over all parameters (lr=0.06)."""
        return torch.optim.SGD(self.parameters(), lr=0.06)
def main():
    """Train the NNCLR LightningModule on CIFAR-10 for 10 epochs on one device.

    The script body is wrapped in a function and run under the ``__main__``
    guard below: the DataLoader uses worker processes, and on platforms that
    start them with ``spawn`` (Windows, macOS) each worker re-imports this
    file, which must not start another training run.
    """
    model = NNCLR()

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Use a GPU when one is available, otherwise fall back to the CPU.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
    trainer.fit(model=model, train_dataloaders=dataloader)


if __name__ == "__main__":
    main()
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/nnclr.py
The model differs in the following ways from the non-distributed implementation:
Distributed Data Parallel is enabled
Synchronized Batch Norm is used in place of standard Batch Norm
Note that Synchronized Batch Norm is optional and the model can also be trained without it. Without Synchronized Batch Norm, the batch norm statistics on each GPU are calculated only from the features on that specific GPU.
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with one or more GPUs.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models.modules import (
NNCLRPredictionHead,
NNCLRProjectionHead,
NNMemoryBankModule,
)
from lightly.transforms.simclr_transform import SimCLRTransform
class NNCLR(pl.LightningModule):
    """NNCLR as a LightningModule.

    Bundles a ResNet-18 backbone, the projection and prediction heads, the
    nearest-neighbor memory bank and the NT-Xent loss.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Use every child module of the ResNet except the last one as the backbone.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = NNCLRProjectionHead(512, 512, 128)
        self.prediction_head = NNCLRPredictionHead(128, 512, 128)
        self.memory_bank = NNMemoryBankModule(size=(4096, 128))
        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction of ``x``."""
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        # The projection is returned without gradient history; gradients flow
        # only through the prediction branch.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute the symmetric NNCLR loss for one batch.

        ``batch[0]`` holds the two augmented views of each image.
        """
        x0, x1 = batch[0]
        # Call the module itself instead of ``forward`` so that any registered
        # module hooks are executed.
        z0, p0 = self(x0)
        z1, p1 = self(x1)
        # Swap each projection for its nearest neighbor from the memory bank;
        # the bank itself is only updated on the second call.
        z0 = self.memory_bank(z0, update=False)
        z1 = self.memory_bank(z1, update=True)
        loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0))
        return loss

    def configure_optimizers(self):
        """Return the SGD optimizer over all parameters (lr=0.06)."""
        return torch.optim.SGD(self.parameters(), lr=0.06)
def main():
    """Train the NNCLR LightningModule on CIFAR-10 for 10 epochs with DDP.

    The script body is wrapped in a function and run under the ``__main__``
    guard below: the DataLoader uses worker processes, and on platforms that
    start them with ``spawn`` (Windows, macOS) each worker re-imports this
    file, which must not start another training run.
    """
    model = NNCLR()

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
    # calculation. Distributed sampling is also enabled with
    # use_distributed_sampler=True (replace_sampler_ddp=True for PyTorch Lightning <2.0).
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="gpu",
        strategy="ddp",
        sync_batchnorm=True,
        use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
    )
    trainer.fit(model=model, train_dataloaders=dataloader)


if __name__ == "__main__":
    main()