FastSiam
Example implementation of the FastSiam architecture. FastSiam is an extension of the well-known SimSiam architecture. It is a self-supervised learning method that averages multiple target predictions to improve training with small batch sizes.
This example can be run from the command line with:
python lightly/examples/pytorch/fastsiam.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
import torchvision
from torch import nn
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import FastSiamTransform
class FastSiam(nn.Module):
    """FastSiam model: a backbone followed by SimSiam projection and prediction heads.

    Args:
        backbone: Module that maps an input batch to features. Its output is
            flattened from dimension 1 onwards before it is passed to the
            projection head, so the flattened size must equal ``input_dim``.
        input_dim: Number of flattened backbone features per sample.
        proj_hidden_dim: Hidden dimension of the projection head.
        output_dim: Output dimension of both the projection and prediction head.
        pred_hidden_dim: Hidden dimension of the prediction head.

    The default dimensions reproduce the original hard-coded configuration
    (projection head 512 -> 512 -> 128, prediction head 128 -> 64 -> 128).
    """

    def __init__(
        self,
        backbone,
        input_dim=512,
        proj_hidden_dim=512,
        output_dim=128,
        pred_hidden_dim=64,
    ):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(
            input_dim, proj_hidden_dim, output_dim
        )
        self.prediction_head = SimSiamPredictionHead(
            output_dim, pred_hidden_dim, output_dim
        )

    def forward(self, x):
        """Return ``(z, p)`` for the batch ``x``.

        ``z`` is the projection of the flattened backbone features, detached
        from the autograd graph so that no gradient flows through it when it
        is used as a target. ``p`` is the prediction computed from the
        (non-detached) projection.
        """
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Detach only after p has been computed so p still backpropagates
        # through the projection head and the backbone.
        z = z.detach()
        return z, p
def main():
    """Train FastSiam on CIFAR-10 for ten epochs with a plain PyTorch loop."""
    resnet = torchvision.models.resnet18()
    # Use every child module of the ResNet except the last one as the backbone.
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = FastSiam(backbone)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    transform = FastSiamTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    criterion = NegativeCosineSimilarity()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

    print("Starting Training")
    for epoch in range(10):
        total_loss = 0
        for batch in dataloader:
            # The first element of the batch holds the views of each image.
            views = batch[0]
            features = [model(view.to(device)) for view in views]
            zs = torch.stack([z for z, _ in features])
            ps = torch.stack([p for _, p in features])

            loss = 0.0
            for i, p in enumerate(ps):
                # The target for view i is the mean projection of all
                # other views; the mask excludes view i itself.
                mask = torch.arange(len(views), device=device) != i
                loss += criterion(p, torch.mean(zs[mask], dim=0)) / len(views)

            total_loss += loss.detach()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        avg_loss = total_loss / len(dataloader)
        print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


# The guard keeps DataLoader worker processes (num_workers > 0) from
# re-running the training code when the module is imported by them, which
# happens on platforms that start workers with "spawn".
if __name__ == "__main__":
    main()
This example can be run from the command line with:
python lightly/examples/pytorch_lightning/fastsiam.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import FastSiamTransform
class FastSiam(pl.LightningModule):
    """FastSiam as a LightningModule with a ResNet-18 backbone.

    Args:
        proj_hidden_dim: Hidden dimension of the projection head.
        output_dim: Output dimension of both the projection and prediction head.
        pred_hidden_dim: Hidden dimension of the prediction head.

    The defaults reproduce the original hard-coded configuration
    (projection head 512 -> 512 -> 128, prediction head 128 -> 64 -> 128).
    """

    def __init__(self, proj_hidden_dim=512, output_dim=128, pred_hidden_dim=64):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Use every child module of the ResNet except the last one as the backbone.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        # 512 is the projection head input size used for this ResNet-18 backbone.
        self.projection_head = SimSiamProjectionHead(512, proj_hidden_dim, output_dim)
        self.prediction_head = SimSiamPredictionHead(
            output_dim, pred_hidden_dim, output_dim
        )
        self.criterion = NegativeCosineSimilarity()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for ``x``."""
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Detach only after p has been computed so p still backpropagates
        # through the projection head and the backbone.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute, log and return the FastSiam loss for one batch.

        Each view's prediction is compared with the mean of the detached
        projections of all other views; the per-view losses are averaged.
        """
        # The first element of the batch holds the views of each image.
        views = batch[0]
        # Call the module itself rather than forward() so registered hooks run.
        features = [self(view) for view in views]
        zs = torch.stack([z for z, _ in features])
        ps = torch.stack([p for _, p in features])

        loss = 0.0
        for i, p in enumerate(ps):
            # The mask excludes view i from its own target.
            mask = torch.arange(len(views), device=self.device) != i
            loss += self.criterion(p, torch.mean(zs[mask], dim=0)) / len(views)

        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters with learning rate 0.06."""
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
def main():
    """Train FastSiam on CIFAR-10 for ten epochs with the Lightning Trainer."""
    model = FastSiam()

    transform = FastSiamTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Fall back to the CPU when no CUDA device is available.
    accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
    trainer.fit(model=model, train_dataloaders=dataloader)


# The guard keeps DataLoader worker processes (num_workers > 0) from
# re-running the training code when the module is imported by them, which
# happens on platforms that start workers with "spawn".
if __name__ == "__main__":
    main()
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/fastsiam.py
The model differs from the non-distributed implementation in the following ways:
- Distributed Data Parallel is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm
Note that Synchronized Batch Norm is optional and the model can also be trained without it. Without Synchronized Batch Norm the batch norm for each GPU is only calculated based on the features on that specific GPU.
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with one or more GPUs.
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from lightly.transforms import FastSiamTransform
class FastSiam(pl.LightningModule):
    """FastSiam as a LightningModule with a ResNet-18 backbone.

    Args:
        proj_hidden_dim: Hidden dimension of the projection head.
        output_dim: Output dimension of both the projection and prediction head.
        pred_hidden_dim: Hidden dimension of the prediction head.

    The defaults reproduce the original hard-coded configuration
    (projection head 512 -> 512 -> 128, prediction head 128 -> 64 -> 128).
    """

    def __init__(self, proj_hidden_dim=512, output_dim=128, pred_hidden_dim=64):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Use every child module of the ResNet except the last one as the backbone.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        # 512 is the projection head input size used for this ResNet-18 backbone.
        self.projection_head = SimSiamProjectionHead(512, proj_hidden_dim, output_dim)
        self.prediction_head = SimSiamPredictionHead(
            output_dim, pred_hidden_dim, output_dim
        )
        self.criterion = NegativeCosineSimilarity()

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction for ``x``."""
        f = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        # Detach only after p has been computed so p still backpropagates
        # through the projection head and the backbone.
        z = z.detach()
        return z, p

    def training_step(self, batch, batch_idx):
        """Compute, log and return the FastSiam loss for one batch.

        Each view's prediction is compared with the mean of the detached
        projections of all other views; the per-view losses are averaged.
        """
        # The first element of the batch holds the views of each image.
        views = batch[0]
        # Call the module itself rather than forward() so registered hooks run.
        features = [self(view) for view in views]
        zs = torch.stack([z for z, _ in features])
        ps = torch.stack([p for _, p in features])

        loss = 0.0
        for i, p in enumerate(ps):
            # The mask excludes view i from its own target.
            mask = torch.arange(len(views), device=self.device) != i
            loss += self.criterion(p, torch.mean(zs[mask], dim=0)) / len(views)

        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters with learning rate 0.06."""
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim
def main():
    """Train FastSiam on CIFAR-10 for ten epochs with DDP on the available GPUs."""
    model = FastSiam()

    transform = FastSiamTransform(input_size=32)
    dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform
    )
    # or create a dataset from a folder containing images or videos:
    # dataset = LightlyDataset("path/to/folder", transform=transform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    # Train with DDP and use Synchronized Batch Norm for a more accurate batch
    # norm calculation. Distributed sampling is also enabled with
    # use_distributed_sampler=True (this argument is called replace_sampler_ddp
    # in PyTorch Lightning versions before 2.0).
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="gpu",
        strategy="ddp",
        sync_batchnorm=True,
        use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
    )
    trainer.fit(model=model, train_dataloaders=dataloader)


# The guard keeps DataLoader worker processes (num_workers > 0) from
# re-running the training code when the module is imported by them, which
# happens on platforms that start workers with "spawn".
if __name__ == "__main__":
    main()