DCL & DCLW

Example implementation of the Decoupled Contrastive Learning (DCL) architecture. DCL is based on the SimCLR architecture and only introduces a new loss function. The new loss is called DCL loss and also comes with a weighted variant called DCLW loss. DCL improves upon the widely used NT-Xent loss (also known as InfoNCE loss) by removing the negative-positive coupling (NPC) effect present in those losses: in NT-Xent, the positive pair appears in the softmax denominator, which couples the gradient of the positive term to the negatives. Removing it speeds up model training and allows the use of smaller batch sizes.
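The decoupling can be sketched in a few lines: NT-Xent includes the positive pair in its softmax denominator, while DCL excludes it. Below is a minimal, self-contained sketch of the DCL loss, not Lightly's actual implementation; the function name `dcl_loss` and its arguments are illustrative.

```python
import torch
import torch.nn.functional as F


def dcl_loss(z0, z1, temperature=0.1):
    """Sketch of the DCL loss: NT-Xent with the positive pair
    removed from the softmax denominator (the "decoupling")."""
    z0 = F.normalize(z0, dim=1)
    z1 = F.normalize(z1, dim=1)
    n = z0.size(0)
    z = torch.cat([z0, z1], dim=0)             # (2n, d)
    sim = z @ z.t() / temperature              # pairwise cosine similarities
    idx = torch.arange(2 * n)
    pos_idx = torch.cat([idx[n:], idx[:n]])    # sample i's positive is i +/- n
    pos = sim[idx, pos_idx]                    # similarity of each positive pair
    # exclude both self-similarity AND the positive pair from the denominator
    mask = torch.eye(2 * n, dtype=torch.bool)
    mask[idx, pos_idx] = True
    neg = torch.exp(sim).masked_fill(mask, 0.0)
    return (-pos + neg.sum(dim=1).log()).mean()
```

Keeping the `mask[idx, pos_idx] = True` line is the only difference from a plain NT-Xent sketch; dropping it would put the positive back into the denominator.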

Reference:

Decoupled Contrastive Learning, Yeh et al., 2021

DCL is identical to SimCLR but uses DCLLoss instead of NTXentLoss. To use it, copy the example code from SimCLR and make the following adjustment:

# instead of this
from lightly.loss import NTXentLoss
criterion = NTXentLoss()

# use this
from lightly.loss import DCLLoss
criterion = DCLLoss()

Below you can find a fully runnable example using the SimCLR architecture with DCL loss.

This example can be run from the command line with:

python lightly/examples/pytorch/dcl.py
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
from torch import nn
import torchvision

from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import DCLLoss, DCLWLoss
from lightly.models.modules import SimCLRProjectionHead


class DCL(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(512, 512, 128)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = DCL(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

collate_fn = SimCLRCollateFunction(
    input_size=32,
    gaussian_blur=0.,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = DCLLoss()
# or use the weighted DCLW loss:
# criterion = DCLWLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
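The DCLW variant additionally weights the positive alignment term of each sample with a negative von Mises-Fisher weighting, so that harder positives (pairs whose two views are less similar) contribute more to the loss. A hedged sketch of that weighting follows; the function name `negative_vmf_weights` and the `sigma=0.5` default are assumptions for illustration, not Lightly's exact API.

```python
import torch
import torch.nn.functional as F


def negative_vmf_weights(z0, z1, sigma=0.5):
    """Sketch of the DCLW weighting: 2 - exp(s_i/sigma) / mean_k exp(s_k/sigma),
    where s_i is the cosine similarity of positive pair i. Harder positives
    (lower s_i) get larger weights; the weights average to 1 over the batch."""
    z0 = F.normalize(z0, dim=1)
    z1 = F.normalize(z1, dim=1)
    sim = (z0 * z1).sum(dim=1)  # cosine similarity of each positive pair
    # N * softmax(sim / sigma) equals exp(sim/sigma) / mean_k exp(sim_k/sigma)
    return 2 - z0.size(0) * F.softmax(sim / sigma, dim=0)
```

In DCLW these per-sample weights multiply only the positive term of the DCL loss; the decoupled denominator over negatives is unchanged.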