SwaV
Example implementation of the SwaV architecture. This model takes advantage of contrastive methods without requiring pairwise feature comparisons. Instead of comparing features directly, as in contrastive learning, it simultaneously clusters the data while enforcing consistency between the cluster assignments produced for different augmentations of the same image. It can be trained with both large and small batch sizes.
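Before the full example, here is a minimal sketch of the swapped-prediction objective at the core of SwaV (a simplified illustration, not the lightly internals): soft cluster assignments ("codes") for each view are computed with a few Sinkhorn-Knopp iterations, and each view's codes must be predictable from the other view's prototype scores.

import torch
import torch.nn.functional as F

@torch.no_grad()
def sinkhorn(scores, eps=0.05, n_iters=3):
    # Soft assignments that spread the batch evenly across prototypes.
    q = torch.exp(scores / eps).t()  # (n_prototypes, batch_size)
    q /= q.sum()
    n_prototypes, batch_size = q.shape
    for _ in range(n_iters):
        q /= q.sum(dim=1, keepdim=True)  # normalize over prototypes
        q /= n_prototypes
        q /= q.sum(dim=0, keepdim=True)  # normalize over samples
        q /= batch_size
    return (q * batch_size).t()  # (batch_size, n_prototypes)

def swapped_prediction_loss(scores_a, scores_b, temperature=0.1):
    # Predict view A's codes from view B's scores and vice versa.
    q_a, q_b = sinkhorn(scores_a), sinkhorn(scores_b)
    log_p_a = F.log_softmax(scores_a / temperature, dim=1)
    log_p_b = F.log_softmax(scores_b / temperature, dim=1)
    return -0.5 * ((q_a * log_p_b).sum(dim=1) + (q_b * log_p_a).sum(dim=1)).mean()

# Prototype scores for two augmented views of the same 8 images.
scores_a, scores_b = torch.randn(8, 16), torch.randn(8, 16)
print(swapped_prediction_loss(scores_a, scores_b))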
This example can be run from the command line with:
python lightly/examples/pytorch/swav.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.transforms.swav_transform import SwaVTransform
class SwaV(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, n_prototypes=512)
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
x = self.projection_head(x)
x = nn.functional.normalize(x, dim=1, p=2)
p = self.prototypes(x)
return p
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SwaV(backbone)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = SwaVTransform()
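# SwaVTransform produces a list of multi-crop views per image; by default
# 2 high-resolution and 6 low-resolution crops (configurable via crop_counts),
# which is why the training loop below splits the views at index 2.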
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
shuffle=True,
drop_last=True,
num_workers=8,
)
criterion = SwaVLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch in dataloader:
views = batch[0]
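# Project the prototype vectors back onto the unit sphere before each
# forward pass, so assignments are computed with normalized prototypes.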
model.prototypes.normalize()
multi_crop_features = [model(view.to(device)) for view in views]
high_resolution = multi_crop_features[:2]
low_resolution = multi_crop_features[2:]
loss = criterion(high_resolution, low_resolution)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example, implemented with PyTorch Lightning, can be run from the command line with:
python lightly/examples/pytorch_lightning/swav.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.transforms.swav_transform import SwaVTransform
class SwaV(pl.LightningModule):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, n_prototypes=512)
self.criterion = SwaVLoss()
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
x = self.projection_head(x)
x = nn.functional.normalize(x, dim=1, p=2)
p = self.prototypes(x)
return p
def training_step(self, batch, batch_idx):
self.prototypes.normalize()
views = batch[0]
multi_crop_features = [self.forward(view.to(self.device)) for view in views]
high_resolution = multi_crop_features[:2]
low_resolution = multi_crop_features[2:]
loss = self.criterion(high_resolution, low_resolution)
return loss
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=0.001)
return optim
model = SwaV()
transform = SwaVTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
shuffle=True,
drop_last=True,
num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/swav.py
The model differs from the non-distributed implementation in the following ways:
- Distributed Data Parallel is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm
- Distributed Sinkhorn is used in the loss calculation
Note that Synchronized Batch Norm and distributed Sinkhorn are optional; the model can also be trained without them. Without them, the batch norm statistics and the loss on each GPU are calculated using only the features on that GPU.
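For reference, outside of PyTorch Lightning the same Synchronized Batch Norm conversion can be done with the utility built into PyTorch. A short sketch under the assumption that DDP is set up manually (the DDP wrapping line is illustrative and requires an initialized process group; local_rank is a placeholder):

import torch
import torchvision
from torch import nn

resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
# Replace every BatchNorm layer with its synchronized variant so batch
# statistics are computed across all GPUs in the process group.
backbone = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
# After torch.distributed.init_process_group(...), wrap the model, e.g.:
# model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])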
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.transforms.swav_transform import SwaVTransform
class SwaV(pl.LightningModule):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, n_prototypes=512)
# enable sinkhorn_gather_distributed to gather features from all gpus
# while running the sinkhorn algorithm in the loss calculation
self.criterion = SwaVLoss(sinkhorn_gather_distributed=True)
def forward(self, x):
x = self.backbone(x).flatten(start_dim=1)
x = self.projection_head(x)
x = nn.functional.normalize(x, dim=1, p=2)
p = self.prototypes(x)
return p
def training_step(self, batch, batch_idx):
self.prototypes.normalize()
views = batch[0]
multi_crop_features = [self.forward(view.to(self.device)) for view in views]
high_resolution = multi_crop_features[:2]
low_resolution = multi_crop_features[2:]
loss = self.criterion(high_resolution, low_resolution)
return loss
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=0.001)
return optim
model = SwaV()
transform = SwaVTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
shuffle=True,
drop_last=True,
num_workers=8,
)
# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation. Distributed sampling is also enabled with use_distributed_sampler=True.
trainer = pl.Trainer(
max_epochs=10,
devices="auto",
accelerator="gpu",
strategy="ddp",
sync_batchnorm=True,
use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
SwaV Queue
If you plan to work with small batch sizes (less than 256), please use the SwaV implementation with a queue:
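The queue helps because the Sinkhorn assignment step behaves best when it sees many samples relative to the number of prototypes; queued features from previous batches enlarge the pool used for the assignments without enlarging the batch. A rough illustration of the idea (not the lightly internals), using the settings from the example below:

import torch

batch_size, feature_dim, queue_length = 128, 128, 3840

# Projected features for one high-resolution view of the current batch,
# plus the features queued from earlier batches.
batch_features = torch.randn(batch_size, feature_dim)
queued_features = torch.randn(queue_length, feature_dim)

# Once the queue is active, assignments are computed over
# 128 + 3840 = 3968 feature vectors instead of 128 alone.
assignment_pool = torch.cat([batch_features, queued_features], dim=0)
print(assignment_pool.shape)  # torch.Size([3968, 128])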
This example can be run from the command line with:
python lightly/examples/pytorch/swav_queue.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.swav_transform import SwaVTransform
class SwaV(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.projection_head = SwaVProjectionHead(512, 512, 128)
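# The third argument freezes the prototype weights during the first epoch
# to stabilize early training (the epoch is passed to the prototypes below).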
self.prototypes = SwaVPrototypes(128, 512, 1)
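# Features are queued from the start of training, but only enter the loss
# once the current epoch reaches start_queue_at_epoch; each of the two
# high-resolution views gets its own queue of 3840 projected features.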
self.start_queue_at_epoch = 2
self.queues = nn.ModuleList(
[MemoryBankModule(size=(3840, 128)) for _ in range(2)]
)
def forward(self, high_resolution, low_resolution, epoch):
self.prototypes.normalize()
high_resolution_features = [self._subforward(x) for x in high_resolution]
low_resolution_features = [self._subforward(x) for x in low_resolution]
high_resolution_prototypes = [
self.prototypes(x, epoch) for x in high_resolution_features
]
low_resolution_prototypes = [
self.prototypes(x, epoch) for x in low_resolution_features
]
queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)
return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
def _subforward(self, input):
features = self.backbone(input).flatten(start_dim=1)
features = self.projection_head(features)
features = nn.functional.normalize(features, dim=1, p=2)
return features
@torch.no_grad()
def _get_queue_prototypes(self, high_resolution_features, epoch):
if len(high_resolution_features) != len(self.queues):
raise ValueError(
f"The number of queues ({len(self.queues)}) should be equal to the number of high "
f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
)
# Get the queue features
queue_features = []
for i in range(len(self.queues)):
_, features = self.queues[i](high_resolution_features[i], update=True)
# Queue features are in (num_ftrs X queue_length) shape, while the high res
# features are in (batch_size X num_ftrs). Swap the axes for interoperability.
features = torch.permute(features, (1, 0))
queue_features.append(features)
# If loss calculation with queue prototypes starts at a later epoch,
# just queue the features and return None instead of queue prototypes.
if self.start_queue_at_epoch > 0 and epoch < self.start_queue_at_epoch:
return None
# Assign prototypes
queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
return queue_prototypes
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SwaV(backbone)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = SwaVTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
shuffle=True,
drop_last=True,
num_workers=8,
)
criterion = SwaVLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch in dataloader:
views = batch[0]
views = [view.to(device) for view in views]
high_resolution, low_resolution = views[:2], views[2:]
high_resolution, low_resolution, queue = model(
high_resolution, low_resolution, epoch
)
loss = criterion(high_resolution, low_resolution, queue)
total_loss += loss.detach()
loss.backward()
optimizer.step()
optimizer.zero_grad()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example, implemented with PyTorch Lightning, can be run from the command line with:
python lightly/examples/pytorch_lightning/swav_queue.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.swav_transform import SwaVTransform
class SwaV(pl.LightningModule):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18()
self.backbone = nn.Sequential(*list(resnet.children())[:-1])
self.projection_head = SwaVProjectionHead(512, 512, 128)
self.prototypes = SwaVPrototypes(128, 512, 1)
self.start_queue_at_epoch = 2
self.queues = nn.ModuleList(
[MemoryBankModule(size=(3840, 128)) for _ in range(2)]
)
self.criterion = SwaVLoss()
def training_step(self, batch, batch_idx):
views = batch[0]
high_resolution, low_resolution = views[:2], views[2:]
self.prototypes.normalize()
high_resolution_features = [self._subforward(x) for x in high_resolution]
low_resolution_features = [self._subforward(x) for x in low_resolution]
high_resolution_prototypes = [
self.prototypes(x, self.current_epoch) for x in high_resolution_features
]
low_resolution_prototypes = [
self.prototypes(x, self.current_epoch) for x in low_resolution_features
]
queue_prototypes = self._get_queue_prototypes(high_resolution_features)
loss = self.criterion(
high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
)
return loss
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=0.001)
return optim
def _subforward(self, input):
features = self.backbone(input).flatten(start_dim=1)
features = self.projection_head(features)
features = nn.functional.normalize(features, dim=1, p=2)
return features
@torch.no_grad()
def _get_queue_prototypes(self, high_resolution_features):
if len(high_resolution_features) != len(self.queues):
raise ValueError(
f"The number of queues ({len(self.queues)}) should be equal to the number of high "
f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
)
# Get the queue features
queue_features = []
for i in range(len(self.queues)):
_, features = self.queues[i](high_resolution_features[i], update=True)
# Queue features are in (num_ftrs X queue_length) shape, while the high res
# features are in (batch_size X num_ftrs). Swap the axes for interoperability.
features = torch.permute(features, (1, 0))
queue_features.append(features)
# If loss calculation with queue prototypes starts at a later epoch,
# just queue the features and return None instead of queue prototypes.
if (
self.start_queue_at_epoch > 0
and self.current_epoch < self.start_queue_at_epoch
):
return None
# Assign prototypes
queue_prototypes = [
self.prototypes(x, self.current_epoch) for x in queue_features
]
return queue_prototypes
model = SwaV()
transform = SwaVTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
return 0
dataset = torchvision.datasets.VOCDetection(
"datasets/pascal_voc",
download=True,
transform=transform,
target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=128,
shuffle=True,
drop_last=True,
num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)