BYOL
Example implementation of the BYOL (Bootstrap Your Own Latent) architecture.
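BYOL trains an online network (backbone, projection head, and prediction head) to predict the output of a momentum (target) network. The target network is a copy of the backbone and projection head that receives no gradients and is instead updated as an exponential moving average (EMA) of the online weights. The loss is the negative cosine similarity between the prediction of one augmented view and the momentum projection of the other view, symmetrized over both views. The update_momentum helper used in the examples below is assumed to implement the usual EMA update; a minimal sketch of that idea:

# Minimal sketch of an exponential moving average (EMA) parameter update,
# the behaviour assumed for lightly's update_momentum helper.
import torch

@torch.no_grad()
def ema_update(online: torch.nn.Module, target: torch.nn.Module, m: float) -> None:
    # m is the momentum factor, e.g. 0.99; a larger m means slower target updates.
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.data = m * p_target.data + (1.0 - m) * p_online.data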
This example can be run from the command line with:
python lightly/examples/pytorch/byol.py
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
from torch import nn
import torchvision
import copy
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
class BYOL(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        # The momentum (target) network is a copy of the online backbone and
        # projection head. It is never updated by backpropagation, only via
        # update_momentum in the training loop.
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        # Stop gradient: the momentum projections act as fixed targets.
        z = z.detach()
        return z
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
collate_fn = SimCLRCollateFunction(input_size=32)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)
print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        # Update the momentum networks with an exponential moving average of
        # the online weights before each optimization step.
        update_momentum(model.backbone, model.backbone_momentum, m=0.99)
        update_momentum(model.projection_head, model.projection_head_momentum, m=0.99)
        x0 = x0.to(device)
        x1 = x1.to(device)
        p0 = model(x0)
        z0 = model.forward_momentum(x0)
        p1 = model(x1)
        z1 = model.forward_momentum(x1)
        # Symmetrized loss: each view's prediction is compared with the
        # momentum projection of the other view.
        loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example implemented with PyTorch Lightning can be run from the command line with:
python lightly/examples/pytorch_lightning/byol.py
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
from torch import nn
import torchvision
import copy
import pytorch_lightning as pl
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
class BYOL(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        self.criterion = NegativeCosineSimilarity()

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z

    def training_step(self, batch, batch_idx):
        update_momentum(self.backbone, self.backbone_momentum, m=0.99)
        update_momentum(self.projection_head, self.projection_head_momentum, m=0.99)
        (x0, x1), _, _ = batch
        p0 = self.forward(x0)
        z0 = self.forward_momentum(x0)
        p1 = self.forward(x1)
        z1 = self.forward_momentum(x1)
        loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.06)
model = BYOL()
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
collate_fn = SimCLRCollateFunction(input_size=32)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(max_epochs=10, gpus=gpus)
trainer.fit(model=model, train_dataloaders=dataloader)
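Once training finishes, you will typically keep only the backbone weights for downstream use. A minimal sketch using standard PyTorch calls (the file name is arbitrary and not part of the original example):

# Save the trained backbone weights (file name is just an example).
torch.save(model.backbone.state_dict(), "byol_resnet18_backbone.pth")

# Later, load them into a fresh ResNet-18 backbone.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
backbone.load_state_dict(torch.load("byol_resnet18_backbone.pth"))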
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/byol.py
The model differs from the non-distributed implementation in the following ways:

- Distributed Data Parallel (DDP) is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm
Note that Synchronized Batch Norm is optional; the model can also be trained without it. Without Synchronized Batch Norm, the batch norm statistics on each GPU are computed only from the features on that GPU.
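The sync_batchnorm=True flag passed to the Trainer below asks Lightning to convert the BatchNorm layers so that their statistics are aggregated across all processes. Conceptually this corresponds to PyTorch's SyncBatchNorm conversion; a rough sketch of the manual equivalent (for illustration only, the Trainer flag handles this automatically in the example below):

# Rough manual equivalent of sync_batchnorm=True: replace every BatchNorm
# layer with a SyncBatchNorm layer that shares statistics across processes.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)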
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import torch
from torch import nn
import torchvision
import copy
import pytorch_lightning as pl
from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLProjectionHead
from lightly.models.modules.heads import BYOLPredictionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
class BYOL(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

        self.criterion = NegativeCosineSimilarity()

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z

    def training_step(self, batch, batch_idx):
        update_momentum(self.backbone, self.backbone_momentum, m=0.99)
        update_momentum(self.projection_head, self.projection_head_momentum, m=0.99)
        (x0, x1), _, _ = batch
        p0 = self.forward(x0)
        z0 = self.forward_momentum(x0)
        p1 = self.forward(x1)
        z1 = self.forward_momentum(x1)
        loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.06)
model = BYOL()
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
collate_fn = SimCLRCollateFunction(input_size=32)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
gpus = torch.cuda.device_count()
# train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation
trainer = pl.Trainer(
    max_epochs=10,
    gpus=gpus,
    strategy='ddp',
    sync_batchnorm=True,
)
trainer.fit(model=model, train_dataloaders=dataloader)
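With DDP, each process creates its own dataloader, so batch_size=256 above is the per-GPU batch size. A quick sanity check of the effective global batch size under this assumption:

# Effective (global) batch size under DDP: per-GPU batch size times the
# number of participating GPUs.
global_batch_size = 256 * gpus
print(f"global batch size: {global_batch_size}")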