NNCLR

NNCLR is a self-supervised framework for visual representation learning that builds on contrastive methods. It shares several elements with SimCLR, such as two augmented views of the same image, a projection head, and a contrastive loss. However, it introduces key modifications:

  1. Nearest Neighbor Replacement: Instead of directly contrasting the two augmented views of a sample, NNCLR replaces the embedding of one view with its nearest neighbor from a support set (a memory bank of embeddings from previous batches). Because the neighbor typically comes from a different image with similar content, this adds semantic variation that data augmentation alone cannot provide.

  2. Symmetric Loss: The contrastive loss is computed in both directions, swapping the roles of the two views, and the two terms are averaged to improve training stability.

  3. Architectural Adjustments: NNCLR adds a prediction head on top of the projection head and uses different layer sizes for its heads than SimCLR.
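The nearest-neighbor replacement from item 1 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not Lightly's `NNMemoryBankModule`: embeddings are lists of floats, similarity is cosine similarity on L2-normalized vectors, and the support set is a `collections.deque` with a fixed maximum length so the oldest entries are evicted first (FIFO). The names `SupportSet`, `normalize`, and `nearest_neighbor` are hypothetical.

```python
import math
from collections import deque


def normalize(v):
    """Scale a vector to unit L2 norm so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


class SupportSet:
    """Hypothetical FIFO support set: the oldest embeddings are evicted first."""

    def __init__(self, max_size):
        self.bank = deque(maxlen=max_size)

    def add(self, embeddings):
        # Store normalized copies; deque(maxlen=...) drops the oldest on overflow.
        for e in embeddings:
            self.bank.append(normalize(e))

    def nearest_neighbor(self, query):
        """Return the stored embedding with the highest cosine similarity to query.

        Raises ValueError if the support set is still empty.
        """
        q = normalize(query)
        return max(self.bank, key=lambda s: sum(a * b for a, b in zip(q, s)))


support = SupportSet(max_size=3)
support.add([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
support.add([[-1.0, 0.0]])  # evicts [1.0, 0.0], the oldest entry

# The query points mostly along +x, but [1.0, 0.0] has been evicted,
# so the closest remaining entry is the normalized [0.7, 0.7].
neighbor = support.nearest_neighbor([0.9, 0.1])
print([round(x, 3) for x in neighbor])  # [0.707, 0.707]
```

In NNCLR, the returned neighbor stands in for the query view's embedding and serves as the positive for the other view's prediction, which is why the neighbor can differ semantically from the augmented view itself.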

These improvements result in significantly better performance across multiple self-supervised learning benchmarks. Compared to SimCLR and other self-supervised methods, NNCLR achieves:

  • Higher ImageNet linear evaluation accuracy.

  • Improved semi-supervised learning results.

  • Superior transfer learning performance, outperforming BYOL, SimCLR, and even supervised ImageNet pretraining on 8 of the 12 transfer benchmarks evaluated.

Key Components

  • Data Augmentations: NNCLR applies the same transformations as SimCLR, including random cropping, resizing, color jittering, and Gaussian blur, to create diverse views of the same image.

  • Backbone: A convolutional neural network (typically ResNet) encodes augmented images into feature representations.

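  • Projection and Prediction Heads: A multilayer perceptron (MLP) projection head maps backbone features into the embedding space used for the nearest-neighbor lookup. A second MLP, the prediction head, maps these projections to the outputs that are compared against the nearest neighbors in the loss.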

  • Memory Bank: NNCLR maintains a first-in, first-out (FIFO) memory bank, also called the support set, that stores embeddings from previous batches. The oldest embeddings are discarded as new ones arrive, so the bank stays large and diverse enough to approximate the full dataset.

  • Nearest Neighbor Sampling: Each feature representation is replaced by its nearest neighbor from the memory bank, introducing additional semantic variation beyond standard augmentations.

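  • Contrastive Loss: NNCLR employs the normalized temperature-scaled cross-entropy loss (NT-Xent), which pulls positive pairs together and pushes negative pairs apart. Here a positive pair is the nearest neighbor of one view's embedding and the prediction for the other view of the same image; the other samples in the batch act as negatives.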
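The loss and its symmetric form can be sketched in plain Python. This is an illustrative InfoNCE-style sketch, not Lightly's `NTXentLoss` (whose exact choice of negatives may differ): for each query `i`, the positive is the key at the same index, the other keys in the batch are negatives, and `tau = 0.1` is an assumed temperature. The function names `cosine`, `info_nce`, and `symmetric_loss` are hypothetical, and the toy embeddings are made up.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(queries, keys, tau=0.1):
    """Mean cross-entropy where queries[i] must pick keys[i] among all keys.

    The other keys in the batch act as negatives; tau is the temperature.
    """
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, k) / tau for k in keys]
        # Log-sum-exp with the maximum subtracted for numerical stability.
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(logit - m) for logit in logits))
        total += log_denominator - logits[i]
    return total / len(queries)


def symmetric_loss(nn0, p0, nn1, p1, tau=0.1):
    """Average both directions: view-0 neighbors vs. view-1 predictions and vice versa."""
    return 0.5 * (info_nce(nn0, p1, tau) + info_nce(nn1, p0, tau))


# Toy batch of two samples with 2-D embeddings (made-up values).
nn0 = [[1.0, 0.0], [0.0, 1.0]]  # nearest neighbors of view-0 projections
nn1 = [[0.9, 0.1], [0.1, 0.9]]  # nearest neighbors of view-1 projections
p0 = [[0.8, 0.2], [0.2, 0.8]]   # predictions for view 0
p1 = [[1.0, 0.1], [0.1, 1.0]]   # predictions for view 1

aligned = symmetric_loss(nn0, p0, nn1, p1)
# Reversing the predictions mismatches every positive pair.
swapped = symmetric_loss(nn0, p0[::-1], nn1, p1[::-1])
print(aligned < swapped)  # True: matching pairs give a lower loss
```

Averaging the two directions means each view's prediction is trained against the other view's neighbor, so neither view is privileged.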

Good to Know

  • Optimized for CNNs: NNCLR is specifically designed for convolutional neural networks (CNNs), particularly ResNet. It is not recommended for transformer-based architectures.

  • Augmentation Robustness: Compared to SimCLR, NNCLR is less dependent on strong augmentations since nearest neighbor sampling introduces natural semantic variation. However, performance still benefits from well-chosen augmentations and larger batch sizes.

Reference:

With a Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations, 2021


This example can be run from the command line with:

python lightly/examples/pytorch/nnclr.py
# This example requires the following dependencies to be installed:
# pip install lightly

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
import torchvision
from torch import nn

from lightly.loss import NTXentLoss
from lightly.models.modules import (
    NNCLRPredictionHead,
    NNCLRProjectionHead,
    NNMemoryBankModule,
)
from lightly.transforms.simclr_transform import SimCLRTransform


class NNCLR(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projection_head = NNCLRProjectionHead(512, 512, 128)
        self.prediction_head = NNCLRPredictionHead(128, 512, 128)

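    def forward(self, x):
        # Backbone features flattened to (batch_size, 512) for ResNet-18.
        y = self.backbone(x).flatten(start_dim=1)
        # The projection z is used for the nearest-neighbor lookup; the
        # prediction p is compared against the neighbors of the other view.
        z = self.projection_head(y)
        p = self.prediction_head(z)
        # Stop gradients through z: it only serves as a memory bank query/entry.
        z = z.detach()
        return z, p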


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = NNCLR(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

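# FIFO memory bank of 4096 projections (dimension 128). Only one view per batch
# is enqueued, so it holds the embeddings of the last 16 batches of 256.
memory_bank = NNMemoryBankModule(size=(4096, 128))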
memory_bank.to(device)

transform = SimCLRTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# from lightly.data import LightlyDataset
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
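    for batch in dataloader:
        # batch[0] holds the two augmented views produced by SimCLRTransform.
        x0, x1 = batch[0]
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0, p0 = model(x0)
        z1, p1 = model(x1)
        # Replace each projection with its nearest neighbor in the memory bank.
        # Only the second call enqueues the batch, so each image is stored once.
        z0 = memory_bank(z0, update=False)
        z1 = memory_bank(z1, update=True)
        # Symmetric loss: neighbors of one view vs. predictions of the other.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))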
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")