DCL & DCLW
Example implementation of the Decoupled Contrastive Learning (DCL) architecture. DCL is based on the SimCLR architecture and only introduces a new loss function. The new loss is called DCL loss and also comes in a weighted form called DCLW loss. DCL improves upon the widely used NTXent loss (or InfoNCE loss) by removing a negative-positive-coupling effect present in those losses. This speeds up model training and allows the use of smaller batch sizes.
- Reference: Decoupled Contrastive Learning, 2021 (https://arxiv.org/abs/2110.06848)
DCL is identical to SimCLR but uses DCLLoss instead of NTXentLoss. To use it, you can copy the example code from SimCLR and make the following adjustments:
# instead of this
from lightly.loss import NTXentLoss
criterion = NTXentLoss()
# use this
from lightly.loss import DCLLoss
criterion = DCLLoss()
Below you can also find fully runnable examples using the SimCLR architecture with DCL loss.
This example can be run from the command line with:
python lightly/examples/pytorch/dcl.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import DCLLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class DCL(nn.Module):
    """SimCLR-style model used together with the DCL loss.

    The input is passed through ``backbone``, the result is flattened to one
    feature vector per sample, and that vector is fed to a
    ``SimCLRProjectionHead``.

    Args:
        backbone: Module applied to the input batch. Its output is flattened
            from dimension 1 onwards before it reaches the projection head.
        input_dim: Input dimension passed to the projection head; it has to
            match the size of the flattened backbone output.
        hidden_dim: Hidden dimension passed to the projection head.
        output_dim: Output dimension passed to the projection head.

    The defaults (512, 512, 128) are the values that were previously
    hard-coded, so ``DCL(backbone)`` behaves exactly as before.
    """

    def __init__(
        self,
        backbone: nn.Module,
        input_dim: int = 512,
        hidden_dim: int = 512,
        output_dim: int = 128,
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(input_dim, hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projection head output for the batch ``x``."""
        # Collapse everything after the batch dimension into one feature axis.
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)
# Multi-process data loading (num_workers > 0) re-imports this file in every
# worker process on platforms that start processes with "spawn" (e.g. Windows
# and macOS). Guarding the script body keeps those workers from re-running the
# training code.
if __name__ == "__main__":
    resnet = torchvision.models.resnet18()
    # Use every child module of the ResNet except the last one as the backbone.
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = DCL(backbone)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    criterion = DCLLoss()
    # or use the weighted DCLW loss:
    # criterion = DCLWLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

    print("Starting Training")
    for epoch in range(10):
        total_loss = 0
        for batch in dataloader:
            # batch[0] unpacks into two tensors (presumably the two augmented
            # views produced by the transform); other batch entries are unused.
            x0, x1 = batch[0]
            x0 = x0.to(device)
            x1 = x1.to(device)
            z0 = model(x0)
            z1 = model(x1)
            loss = criterion(z0, z1)
            # Detach so the running sum does not keep the autograd graph alive.
            total_loss += loss.detach()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        avg_loss = total_loss / len(dataloader)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/dcl.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import DCLLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class DCL(pl.LightningModule):
    """LightningModule bundling a ResNet-18 backbone, a SimCLR projection head
    and the DCL loss."""

    def __init__(self):
        super().__init__()
        resnet_children = list(torchvision.models.resnet18().children())
        # Keep every child module of the ResNet except the last one.
        self.backbone = nn.Sequential(*resnet_children[:-1])
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)
        self.criterion = DCLLoss()
        # or use the weighted DCLW loss:
        # self.criterion = DCLWLoss()

    def forward(self, x):
        """Return the projection head output for the batch ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

    def training_step(self, batch, batch_index):
        """Compute the loss between the projections of the two tensors in ``batch[0]``."""
        first_view, second_view = batch[0]
        return self.criterion(self.forward(first_view), self.forward(second_view))

    def configure_optimizers(self):
        """Return plain SGD over all parameters with a learning rate of 0.06."""
        return torch.optim.SGD(self.parameters(), lr=0.06)
# Multi-process data loading (num_workers > 0) re-imports this file in every
# worker process on platforms that start processes with "spawn" (e.g. Windows
# and macOS). Guarding the script body keeps those workers from re-running the
# training code.
if __name__ == "__main__":
    model = DCL()

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train on one GPU when CUDA is available, otherwise on the CPU.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
    trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/dcl.py
The model differs in the following ways from the non-distributed implementation:
Distributed Data Parallel is enabled
Synchronized Batch Norm is used in place of standard Batch Norm
Features are gathered from all GPUs before the loss is calculated
Note that Synchronized Batch Norm and feature gathering are optional; the model can also be trained without them. Without them, the batch norm statistics and the loss on each GPU are calculated only from the features on that specific GPU.
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset using the GPUs available on a single machine.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import DCLLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform
class DCL(pl.LightningModule):
    """LightningModule bundling a ResNet-18 backbone, a SimCLR projection head
    and a DCL loss created with ``gather_distributed=True``."""

    def __init__(self):
        super().__init__()
        resnet_children = list(torchvision.models.resnet18().children())
        # Keep every child module of the ResNet except the last one.
        self.backbone = nn.Sequential(*resnet_children[:-1])
        self.projection_head = SimCLRProjectionHead(512, 2048, 2048)

        # enable gather_distributed to gather features from all gpus
        # before calculating the loss
        self.criterion = DCLLoss(gather_distributed=True)
        # or use the weighted DCLW loss:
        # self.criterion = DCLWLoss(gather_distributed=True)

    def forward(self, x):
        """Return the projection head output for the batch ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

    def training_step(self, batch, batch_index):
        """Compute the loss between the projections of the two tensors in ``batch[0]``."""
        first_view, second_view = batch[0]
        return self.criterion(self.forward(first_view), self.forward(second_view))

    def configure_optimizers(self):
        """Return plain SGD over all parameters with a learning rate of 0.06."""
        return torch.optim.SGD(self.parameters(), lr=0.06)
# Multi-process data loading (num_workers > 0) re-imports this file in every
# worker process on platforms that start processes with "spawn" (e.g. Windows
# and macOS). Guarding the script body keeps those workers from re-running the
# training code.
if __name__ == "__main__":
    model = DCL()

    transform = SimCLRTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train with DDP and use Synchronized Batch Norm for a more accurate batch
    # norm calculation. Distributed sampling is also enabled with
    # use_distributed_sampler=True (the argument is called replace_sampler_ddp
    # in PyTorch Lightning <2.0).
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="gpu",
        strategy="ddp",
        sync_batchnorm=True,
        use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
    )
    trainer.fit(model=model, train_dataloaders=dataloader)