DenseCL
Example implementation of the DenseCL architecture. DenseCL is an extension of MoCo that uses a dense contrastive loss to improve the quality of the learned representations for object detection and segmentation tasks. While initially designed for MoCo, DenseCL can also be combined with other self-supervised learning methods.
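The key ingredient is the dense matching step: each spatial feature of the query view is paired with its most similar spatial feature of the key view, and the matched pairs serve as positives for a local contrastive loss, while pooled features feed a standard global contrastive loss. The examples below rely on lightly's utils.select_most_similar for this step; the following is only a rough sketch of the underlying idea, assuming backbone features of shape (B, H*W, C) and local projections of shape (B, H*W, D):

import torch
import torch.nn.functional as F

def select_most_similar_sketch(query_features, key_features, key_local):
    # Cosine similarity between all query and key locations: (B, H*W, H*W).
    q = F.normalize(query_features, dim=-1)
    k = F.normalize(key_features, dim=-1)
    similarity = torch.bmm(q, k.transpose(1, 2))
    # For every query location, find the index of the most similar key location.
    match_index = similarity.argmax(dim=-1)  # (B, H*W)
    # Gather the local key projections at the matched locations: (B, H*W, D).
    index = match_index.unsqueeze(-1).expand(-1, -1, key_local.shape[-1])
    return torch.gather(key_local, dim=1, index=index)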
This example can be run from the command line with:
python lightly/examples/pytorch/densecl.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models import utils
from lightly.models.modules import DenseCLProjectionHead
from lightly.transforms import DenseCLTransform
from lightly.utils.scheduler import cosine_schedule
class DenseCL(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head_global = DenseCLProjectionHead(512, 512, 128)
        self.projection_head_local = DenseCLProjectionHead(512, 512, 128)
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_global_momentum = copy.deepcopy(
            self.projection_head_global
        )
        self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        utils.deactivate_requires_grad(self.backbone_momentum)
        utils.deactivate_requires_grad(self.projection_head_global_momentum)
        utils.deactivate_requires_grad(self.projection_head_local_momentum)

    def forward(self, x):
        query_features = self.backbone(x)
        query_global = self.pool(query_features).flatten(start_dim=1)
        query_global = self.projection_head_global(query_global)
        query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)
        query_local = self.projection_head_local(query_features)
        # Shapes: (B, H*W, C), (B, D), (B, H*W, D)
        return query_features, query_global, query_local

    @torch.no_grad()
    def forward_momentum(self, x):
        # The key view is encoded with the momentum encoders.
        key_features = self.backbone_momentum(x)
        key_global = self.pool(key_features).flatten(start_dim=1)
        key_global = self.projection_head_global_momentum(key_global)
        key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)
        key_local = self.projection_head_local_momentum(key_features)
        return key_features, key_global, key_local
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-2])
model = DenseCL(backbone)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = DenseCLTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
criterion_global = NTXentLoss(memory_bank_size=(4096, 128))
criterion_local = NTXentLoss(memory_bank_size=(4096, 128))
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)
epochs = 10
print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    # Momentum increases from 0.996 to 1 over the course of training.
    momentum = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        x_query, x_key = batch[0]
        # Update the momentum encoders before each step.
        utils.update_momentum(model.backbone, model.backbone_momentum, m=momentum)
        utils.update_momentum(
            model.projection_head_global,
            model.projection_head_global_momentum,
            m=momentum,
        )
        utils.update_momentum(
            model.projection_head_local,
            model.projection_head_local_momentum,
            m=momentum,
        )
        x_query = x_query.to(device)
        x_key = x_key.to(device)
        query_features, query_global, query_local = model(x_query)
        key_features, key_global, key_local = model.forward_momentum(x_key)
        # Match every query feature to its most similar key feature.
        key_local = utils.select_most_similar(query_features, key_features, key_local)
        query_local = query_local.flatten(end_dim=1)
        key_local = key_local.flatten(end_dim=1)
        loss_global = criterion_global(query_global, key_global)
        loss_local = criterion_local(query_local, key_local)
        lambda_ = 0.5
        loss = (1 - lambda_) * loss_global + lambda_ * loss_local
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example is also available as a PyTorch Lightning implementation. It can be run from the command line with:
python lightly/examples/pytorch_lightning/densecl.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models import utils
from lightly.models.modules import DenseCLProjectionHead
from lightly.transforms import DenseCLTransform
from lightly.utils.scheduler import cosine_schedule
class DenseCL(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.projection_head_global = DenseCLProjectionHead(512, 512, 128)
        self.projection_head_local = DenseCLProjectionHead(512, 512, 128)
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_global_momentum = copy.deepcopy(
            self.projection_head_global
        )
        self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        utils.deactivate_requires_grad(self.backbone_momentum)
        utils.deactivate_requires_grad(self.projection_head_global_momentum)
        utils.deactivate_requires_grad(self.projection_head_local_momentum)

        self.criterion_global = NTXentLoss(memory_bank_size=(4096, 128))
        self.criterion_local = NTXentLoss(memory_bank_size=(4096, 128))

    def forward(self, x):
        query_features = self.backbone(x)
        query_global = self.pool(query_features).flatten(start_dim=1)
        query_global = self.projection_head_global(query_global)
        query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)
        query_local = self.projection_head_local(query_features)
        # Shapes: (B, H*W, C), (B, D), (B, H*W, D)
        return query_features, query_global, query_local

    @torch.no_grad()
    def forward_momentum(self, x):
        # The key view is encoded with the momentum encoders.
        key_features = self.backbone_momentum(x)
        key_global = self.pool(key_features).flatten(start_dim=1)
        key_global = self.projection_head_global_momentum(key_global)
        key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)
        key_local = self.projection_head_local_momentum(key_features)
        return key_features, key_global, key_local

    def training_step(self, batch, batch_idx):
        # Momentum from 0.996 to 1 over the 10 training epochs.
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        utils.update_momentum(self.backbone, self.backbone_momentum, m=momentum)
        utils.update_momentum(
            self.projection_head_global,
            self.projection_head_global_momentum,
            m=momentum,
        )
        utils.update_momentum(
            self.projection_head_local,
            self.projection_head_local_momentum,
            m=momentum,
        )
        x_query, x_key = batch[0]
        query_features, query_global, query_local = self(x_query)
        key_features, key_global, key_local = self.forward_momentum(x_key)
        # Match every query feature to its most similar key feature.
        key_local = utils.select_most_similar(query_features, key_features, key_local)
        query_local = query_local.flatten(end_dim=1)
        key_local = key_local.flatten(end_dim=1)
        loss_global = self.criterion_global(query_global, key_global)
        loss_local = self.criterion_local(query_local, key_local)
        lambda_ = 0.5
        loss = (1 - lambda_) * loss_global + lambda_ * loss_local
        return loss

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
model = DenseCL()
transform = DenseCLTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/densecl.py
The model differs in the following ways from the non-distributed implementation:

- Distributed Data Parallel is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm

Note that Synchronized Batch Norm is optional and the model can also be trained without it. Without Synchronized Batch Norm, the batch norm statistics for each GPU are calculated based only on the features on that specific GPU.
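The sync_batchnorm=True flag of the Lightning Trainer below takes care of the conversion. For reference, the same conversion can be done manually in plain PyTorch; a minimal sketch:

import torchvision
from torch import nn

# Recursively replace every BatchNorm layer in a model with SyncBatchNorm.
# The converted layers only synchronize statistics once the model runs under
# DistributedDataParallel in an initialized process group.
resnet = torchvision.models.resnet18()
resnet = nn.SyncBatchNorm.convert_sync_batchnorm(resnet)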
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NTXentLoss
from lightly.models import utils
from lightly.models.modules import DenseCLProjectionHead
from lightly.transforms import DenseCLTransform
from lightly.utils.scheduler import cosine_schedule
class DenseCL(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.projection_head_global = DenseCLProjectionHead(512, 512, 128)
        self.projection_head_local = DenseCLProjectionHead(512, 512, 128)
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_global_momentum = copy.deepcopy(
            self.projection_head_global
        )
        self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        utils.deactivate_requires_grad(self.backbone_momentum)
        utils.deactivate_requires_grad(self.projection_head_global_momentum)
        utils.deactivate_requires_grad(self.projection_head_local_momentum)

        self.criterion_global = NTXentLoss(memory_bank_size=(4096, 128))
        self.criterion_local = NTXentLoss(memory_bank_size=(4096, 128))

    def forward(self, x):
        query_features = self.backbone(x)
        query_global = self.pool(query_features).flatten(start_dim=1)
        query_global = self.projection_head_global(query_global)
        query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)
        query_local = self.projection_head_local(query_features)
        # Shapes: (B, H*W, C), (B, D), (B, H*W, D)
        return query_features, query_global, query_local

    @torch.no_grad()
    def forward_momentum(self, x):
        # The key view is encoded with the momentum encoders.
        key_features = self.backbone_momentum(x)
        key_global = self.pool(key_features).flatten(start_dim=1)
        key_global = self.projection_head_global_momentum(key_global)
        key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)
        key_local = self.projection_head_local_momentum(key_features)
        return key_features, key_global, key_local

    def training_step(self, batch, batch_idx):
        # Momentum from 0.996 to 1 over the 10 training epochs.
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        utils.update_momentum(self.backbone, self.backbone_momentum, m=momentum)
        utils.update_momentum(
            self.projection_head_global,
            self.projection_head_global_momentum,
            m=momentum,
        )
        utils.update_momentum(
            self.projection_head_local,
            self.projection_head_local_momentum,
            m=momentum,
        )
        x_query, x_key = batch[0]
        query_features, query_global, query_local = self(x_query)
        key_features, key_global, key_local = self.forward_momentum(x_key)
        # Match every query feature to its most similar key feature.
        key_local = utils.select_most_similar(query_features, key_features, key_local)
        query_local = query_local.flatten(end_dim=1)
        key_local = key_local.flatten(end_dim=1)
        loss_global = self.criterion_global(query_global, key_global)
        loss_local = self.criterion_local(query_local, key_local)
        lambda_ = 0.5
        loss = (1 - lambda_) * loss_global + lambda_ * loss_local
        return loss

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
model = DenseCL()
transform = DenseCLTransform(input_size=32)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation. Distributed sampling is also enabled with use_distributed_sampler=True
# (replace_sampler_ddp=True for PyTorch Lightning <2.0).
trainer = pl.Trainer(
    max_epochs=10,
    devices="auto",
    accelerator="gpu",
    strategy="ddp",
    sync_batchnorm=True,
    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)