BYOL

BYOL (Bootstrap Your Own Latent) [0] is a self-supervised learning framework for visual representation learning without negative samples. Unlike contrastive learning methods such as MoCo [1] and SimCLR [2], which compare positive and negative pairs, BYOL uses two neural networks, an “online” and a “target” network, where the online network is trained to predict the target network’s representation of the same image under a different augmentation, resulting in an iterative bootstrapping of the latent representations. The target’s weights are updated as an exponential moving average (EMA) of the online network’s weights, and the authors show that this alone is sufficient to prevent collapse to trivial solutions. They also show that, thanks to the absence of negative samples, BYOL is less sensitive to the batch size during training, and it achieves state-of-the-art performance on several semi-supervised and transfer learning benchmarks.
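
The EMA update is the heart of the method. Below is a minimal sketch in plain PyTorch, assuming online and target are two architecturally identical modules; the full example further down uses lightly's update_momentum for the same step. The paper starts the momentum at tau = 0.996 and anneals it toward 1.

import torch

@torch.no_grad()
def ema_update(online: torch.nn.Module, target: torch.nn.Module, tau: float = 0.996) -> None:
    # target_weights <- tau * target_weights + (1 - tau) * online_weights
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)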

Key Components

  • Data Augmentations: BYOL [0] uses the same augmentations as SimCLR [2], namely random resized crop, random horizontal flip, color distortion, Gaussian blur, and solarization. The color distortion consists of a random sequence of brightness, contrast, saturation, and hue adjustments, plus an optional grayscale conversion. However, the augmentation hyperparameters differ from those of SimCLR [2] (a torchvision sketch of the color distortion follows this list).

  • Backbone: BYOL [0] uses ResNet-type convolutional backbones as the online and target networks. The authors do not evaluate other architectures.

  • Projection & Prediction Head: A projection head maps the output of the backbone to a lower-dimensional space. The target network’s projection head is, once again, an EMA of the online network’s projection head. A notable architectural choice is the additional prediction head, a second MLP appended only to the online network’s projection head (see the MLP sketch after this list).

  • Loss Function: BYOL [0] uses a negative cosine similarity loss between the online network’s prediction output and the target network’s projection output. The loss is symmetrized by swapping the two augmented views (see the loss sketch after this list).
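
As a rough illustration of the color distortion described in the augmentations item, here is a torchvision sketch. The jitter strengths are illustrative assumptions, not BYOL's exact hyperparameters; lightly's BYOLTransform, used in the full example below, implements the complete recipe.

import torchvision.transforms as T

# Apply color jitter with probability 0.8, then optionally convert to grayscale.
# The strengths below are illustrative placeholders, not BYOL's exact values.
color_distortion = T.Compose(
    [
        T.RandomApply(
            [T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        ),
        T.RandomGrayscale(p=0.2),
    ]
)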
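
For intuition, the projection and prediction heads are simple MLPs (linear, batch norm, ReLU, linear). A minimal sketch, assuming the dimensions used in the example below; lightly's BYOLProjectionHead and BYOLPredictionHead provide ready-made equivalents.

from torch import nn

def byol_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    # Linear -> BatchNorm -> ReLU -> Linear, the head structure used by BYOL.
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )

projection_head = byol_mlp(512, 1024, 256)  # on both the online and target networks
prediction_head = byol_mlp(256, 1024, 256)  # on the online network only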
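
The symmetrized loss can be written in a few lines of plain PyTorch. A minimal sketch, where p0/p1 are the online predictions and z0/z1 the target projections of the two views; lightly's NegativeCosineSimilarity (used in the example below) computes the per-pair term.

import torch.nn.functional as F

def byol_loss(p0, z0, p1, z1):
    # Negative cosine similarity between each prediction and the (stop-gradient)
    # target projection of the other view, symmetrized over the two views.
    loss_0 = -F.cosine_similarity(p0, z1.detach(), dim=-1).mean()
    loss_1 = -F.cosine_similarity(p1, z0.detach(), dim=-1).mean()
    return 0.5 * (loss_0 + loss_1)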

Good to Know

  • Backbone Networks: BYOL is specifically optimized for convolutional neural networks, with a focus on ResNet architectures. We do not recommend using it with transformer-based models and instead suggest DINO [3].

References:

[0] Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning, 2020

[1] Momentum Contrast for Unsupervised Visual Representation Learning, 2019

[2] A Simple Framework for Contrastive Learning of Visual Representations, 2020

[3] Emerging Properties in Self-Supervised Vision Transformers, 2021

This example can be run from the command line with:

python lightly/examples/pytorch/byol.py

# This example requires the following dependencies to be installed:
# pip install lightly

# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import copy

import torch
import torchvision
from torch import nn

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (
    BYOLTransform,
    BYOLView1Transform,
    BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule


class BYOL(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

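        # Target ("momentum") copies of the backbone and projection head; their
        # weights are updated via EMA rather than by backpropagation.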
        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# We disable resizing and Gaussian blur for CIFAR-10.
transform = BYOLTransform(
    view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
    view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
    "datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 10

print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
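    # Anneal the EMA momentum from 0.996 toward 1.0 over the course of training.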
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
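        # batch[0] holds the two augmented views produced by BYOLTransform.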
        x0, x1 = batch[0]
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(
            model.projection_head, model.projection_head_momentum, m=momentum_val
        )
        x0 = x0.to(device)
        x1 = x1.to(device)
        p0 = model(x0)
        z0 = model.forward_momentum(x0)
        p1 = model(x1)
        z1 = model.forward_momentum(x1)
        loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")