DINO
Example implementation of the DINO (self-distillation with no labels) architecture. DINO trains a student network to match the output of a momentum-averaged teacher network across multiple augmented views of the same image.
This example can be run from the command line with:
python lightly/examples/pytorch/dino.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import torch
import torchvision
from torch import nn
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
class DINO(torch.nn.Module):
    def __init__(self, backbone, input_dim):
        super().__init__()
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
input_dim = 512
# instead of a resnet you can also use a vision transformer backbone as in the
# original paper (you might have to reduce the batch size in this case):
# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# input_dim = backbone.embed_dim
model = DINO(backbone, input_dim)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
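# DINOTransform implements the multi-crop augmentation from the paper: each
# image is turned into two global crops plus several smaller local crops,
# returned as a list of views with the global crops first.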
transform = DINOTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
criterion = DINOLoss(
    output_dim=2048,
    warmup_teacher_temp_epochs=5,
)
# move loss to correct device because it also contains parameters
criterion = criterion.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 10
print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        views = batch[0]
        # Update the teacher as an exponential moving average of the student.
        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
        update_momentum(model.student_head, model.teacher_head, m=momentum_val)
        views = [view.to(device) for view in views]
        global_views = views[:2]
        # The teacher only sees the global views; the student sees all views.
        teacher_out = [model.forward_teacher(view) for view in global_views]
        student_out = [model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        total_loss += loss.detach()
        loss.backward()
        # Cancel the gradients of the student head's last layer while it is
        # frozen (freeze_last_layer=1, i.e. during the first epoch).
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example implemented as a PyTorch Lightning module can be run from the command line with:
python lightly/examples/pytorch_lightning/dino.py
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
class DINO(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        # input_dim = backbone.embed_dim
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)
        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

    def training_step(self, batch, batch_idx):
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        update_momentum(self.student_head, self.teacher_head, m=momentum)
        views = batch[0]
        views = [view.to(self.device) for view in views]
        global_views = views[:2]
        teacher_out = [self.forward_teacher(view) for view in global_views]
        student_out = [self.forward(view) for view in views]
        loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)
        return loss

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim
model = DINO()
transform = DINOTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
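After fitting, Lightning has saved a checkpoint that can be reloaded for inference. A minimal sketch, assuming a hypothetical checkpoint path:

# Reload the trained module from a Lightning checkpoint (the path below is a
# placeholder) and embed a batch of images with the student backbone.
model = DINO.load_from_checkpoint("path/to/checkpoint.ckpt", map_location="cpu")
model.eval()
with torch.no_grad():
    images = torch.rand(4, 3, 224, 224)  # hypothetical input batch
    embeddings = model.student_backbone(images).flatten(start_dim=1)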
This example runs on multiple GPUs using Distributed Data Parallel (DDP) training with PyTorch Lightning. At least one GPU must be available on the system. The example can be run from the command line with:
python lightly/examples/pytorch_lightning_distributed/dino.py
The model differs in the following ways from the non-distributed implementation:

- Distributed Data Parallel is enabled
- Synchronized Batch Norm is used in place of standard Batch Norm
- Distributed Sampling is used in the dataloader
Note that Synchronized Batch Norm is optional and the model can also be trained without it. Without Synchronized Batch Norm, the batch norm statistics on each GPU are calculated based only on the features seen on that specific GPU. Distributed Sampling makes sure that each distributed process sees only a subset of the data.
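For reference, passing sync_batchnorm=True to the Trainer (as in the code below) is roughly equivalent to converting the batch norm layers manually before training; a minimal sketch:

# Convert all BatchNorm layers to SyncBatchNorm so that batch statistics are
# computed across all GPUs rather than per device. Lightning performs an
# equivalent conversion when sync_batchnorm=True is set on the Trainer.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)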
# This example requires the following dependencies to be installed:
# pip install lightly
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
class DINO(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        # input_dim = backbone.embed_dim
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)
        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

    def training_step(self, batch, batch_idx):
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        update_momentum(self.student_head, self.teacher_head, m=momentum)
        views = batch[0]
        views = [view.to(self.device) for view in views]
        global_views = views[:2]
        teacher_out = [self.forward_teacher(view) for view in global_views]
        student_out = [self.forward(view) for view in views]
        loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)
        return loss

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim
model = DINO()
transform = DINOTransform()
# we ignore object detection annotations by setting target_transform to return 0
def target_transform(t):
    return 0

dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=target_transform,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
# Train with DDP and use Synchronized Batch Norm for a more accurate batch norm
# calculation. Distributed sampling is also enabled with use_distributed_sampler=True
# (named replace_sampler_ddp in PyTorch Lightning <2.0).
trainer = pl.Trainer(
    max_epochs=10,
    devices="auto",
    accelerator="gpu",
    strategy="ddp",
    sync_batchnorm=True,
    use_distributed_sampler=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
)
trainer.fit(model=model, train_dataloaders=dataloader)
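Note that each of the N processes loads its own dataloader with batch_size=64, so the effective batch size is N * 64. A minimal sketch of what use_distributed_sampler=True sets up for you (assuming the process group is already initialized):

# Roughly equivalent manual setup: a DistributedSampler gives each process a
# disjoint shard of the dataset for every epoch.
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, shuffle=True, drop_last=True
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, sampler=sampler, num_workers=8
)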