lightly.models

The lightly.models package provides model implementations.

Note that the high-level building blocks will be deprecated with lightly version 1.3.0. Instead, use low-level building blocks to build the models yourself.

Example implementations for all models can be found in the Model Examples section.

The package contains an implementation of the commonly used ResNet and adaptations of the architecture which make self-supervised learning simpler.

The package also hosts the Lightly model zoo, a list of downloadable ResNet checkpoints.

.resnet

Custom ResNet Implementation

Note that the architecture we present here differs from the one used in torchvision. We replace the first 7x7 convolution with a 3x3 convolution to make the model faster and better suited to smaller input image resolutions.
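The effect of the smaller stem on feature-map size can be checked with the standard convolution output-size formula. The sketch below assumes that the torchvision stem is a 7x7 convolution (stride 2, padding 3) followed by a 3x3 max pool (stride 2, padding 1), and that the replacement is a 3x3 convolution with stride 1, padding 1 and no max pool. The stride and padding of the replacement are assumptions, not taken from this page, and the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_weights(kernel: int, c_in: int, c_out: int) -> int:
    """Number of weights in a bias-free 2D convolution."""
    return kernel * kernel * c_in * c_out


def torchvision_stem(size: int) -> int:
    # 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1)
    size = conv_out(size, kernel=7, stride=2, padding=3)
    return conv_out(size, kernel=3, stride=2, padding=1)


def small_stem(size: int) -> int:
    # assumed replacement: 3x3 conv (stride 1, padding 1), no max pool
    return conv_out(size, kernel=3, stride=1, padding=1)


for resolution in (32, 64, 224):
    print(resolution, torchvision_stem(resolution), small_stem(resolution))

print(conv_weights(7, 3, 64), conv_weights(3, 3, 64))  # 9408 1728
```

Under these assumptions, a 32x32 input shrinks to an 8x8 map after the torchvision stem but keeps its 32x32 resolution after the 3x3 stem, and the first layer needs 1,728 instead of 9,408 weights for 64 output channels.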

Furthermore, we introduce a ResNet-9 variant for extra-small models. Such models can run, for example, on a microcontroller with 100 kB of storage.

class lightly.models.resnet.BasicBlock(in_planes: int, planes: int, stride: int = 1, num_splits: int = 0)

Implementation of the ResNet Basic Block.

in_planes

Number of input channels.

planes

Number of channels.

stride

Stride of the first convolutional layer.

forward(x: torch.Tensor)

Forward pass through basic ResNet block.

Parameters

x – Tensor of shape bsz x channels x W x H

Returns

Tensor of shape bsz x channels x W x H
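The shape contract above does not show how stride and planes interact. This plain-Python sketch tracks the shape through a block, assuming the standard ResNet basic block: two 3x3 convolutions with padding 1, the stride applied in the first one, output channels equal to planes, and a 1x1 projection shortcut whenever the identity would not match in shape. These details follow the original ResNet design and are not confirmed by this page; the function names are hypothetical.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def conv3x3_out(size: int, stride: int) -> int:
    """Spatial size after a 3x3 convolution with padding 1."""
    return (size + 2 * 1 - 3) // stride + 1


def basic_block_output_shape(x_shape: Shape, planes: int, stride: int = 1) -> Shape:
    """Output shape of a basic block (assumed: two 3x3 convs, expansion 1)."""
    bsz, _, d1, d2 = x_shape
    # the first convolution carries the stride ...
    d1, d2 = conv3x3_out(d1, stride), conv3x3_out(d2, stride)
    # ... the second one keeps the spatial size
    d1, d2 = conv3x3_out(d1, 1), conv3x3_out(d2, 1)
    return (bsz, planes, d1, d2)


def needs_projection_shortcut(in_planes: int, planes: int, stride: int) -> bool:
    """Assumed rule: use a 1x1 projection when the identity would not match in shape."""
    return stride != 1 or in_planes != planes


print(basic_block_output_shape((8, 64, 32, 32), planes=128, stride=2))  # (8, 128, 16, 16)
print(needs_projection_shortcut(64, 128, stride=2))  # True
```

With stride=2 the spatial dimensions halve, so the W x H of the returned tensor can differ from that of the input.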

class lightly.models.resnet.Bottleneck(in_planes: int, planes: int, stride: int = 1, num_splits: int = 0)

Implementation of the ResNet Bottleneck Block.

in_planes

Number of input channels.

planes

Number of channels.

stride

Stride of the first convolutional layer.

forward(x)

Forward pass through bottleneck ResNet block.

Parameters

x – Tensor of shape bsz x channels x W x H

Returns

Tensor of shape bsz x channels x W x H
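The point of the bottleneck design is to make deep networks cheaper. The sketch below counts convolution weights for a block that maps 256 channels to 256 channels, assuming the standard layout from the ResNet paper (1x1 reduction to planes, 3x3 convolution, 1x1 expansion to 4 * planes) rather than anything stated on this page. Batch-norm and shortcut parameters are ignored, and the helper names are hypothetical.

```python
def conv_weights(kernel: int, c_in: int, c_out: int) -> int:
    """Weights of a bias-free 2D convolution."""
    return kernel * kernel * c_in * c_out


def bottleneck_weights(in_planes: int, planes: int, expansion: int = 4) -> int:
    # assumed layout: 1x1 reduce -> 3x3 -> 1x1 expand to expansion * planes
    return (
        conv_weights(1, in_planes, planes)
        + conv_weights(3, planes, planes)
        + conv_weights(1, planes, expansion * planes)
    )


def basic_block_weights(in_planes: int, planes: int) -> int:
    # two 3x3 convolutions, expansion 1
    return conv_weights(3, in_planes, planes) + conv_weights(3, planes, planes)


# both blocks map 256 input channels to 256 output channels
print(bottleneck_weights(256, 64))    # 69632
print(basic_block_weights(256, 256))  # 1179648
```

Under these assumptions the bottleneck block needs roughly 17 times fewer convolution weights than a basic block of the same input and output width.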

class lightly.models.resnet.ResNet(block: torch.nn.modules.module.Module = lightly.models.resnet.BasicBlock, layers: typing.List[int] = [2, 2, 2, 2], num_classes: int = 10, width: float = 1.0, num_splits: int = 0)

ResNet implementation [1].

[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. arXiv:1512.03385

block

ResNet building block type.

layers

List of blocks per layer.

num_classes

Number of classes, i.e. the output dimension of the final linear layer.

width

Multiplier for ResNet width.

forward(x: torch.Tensor)

Forward pass through ResNet.

Parameters

x – Tensor of shape bsz x channels x W x H

Returns

Output tensor of shape bsz x num_classes
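The width multiplier scales the number of channels in every stage. A minimal sketch, assuming the common convention that the first stage has int(64 * width) channels and the four stages use 1x, 2x, 4x and 8x that base, multiplied by the block expansion (1 for a basic block, 4 for a bottleneck). This convention is an assumption, not confirmed by this page, and the function names are hypothetical.

```python
from typing import List


def stage_channels(width: float, expansion: int = 1) -> List[int]:
    """Output channels of the four ResNet stages for a given width multiplier."""
    base = int(64 * width)  # assumed base channel count
    return [base * factor * expansion for factor in (1, 2, 4, 8)]


def feature_dim(width: float, expansion: int = 1) -> int:
    """Dimension of the pooled features fed to the final linear layer."""
    return stage_channels(width, expansion)[-1]


print(stage_channels(1.0))             # [64, 128, 256, 512]
print(stage_channels(0.125))           # [8, 16, 32, 64]
print(feature_dim(1.0, expansion=4))   # 2048
```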

lightly.models.resnet.ResNetGenerator(name: str = 'resnet-18', width: float = 1, num_classes: int = 10, num_splits: int = 0)

Builds and returns the specified ResNet.

Parameters
  • name – ResNet version from resnet-{9, 18, 34, 50, 101, 152}.

  • width – Multiplier for the ResNet width.

  • num_classes – Output dimension of the last layer.

  • num_splits – Number of splits to use for SplitBatchNorm (for the MoCo model). Increase this number to simulate multi-GPU behavior, e.g. num_splits=8 simulates an 8-GPU cluster. num_splits=0 uses the normal PyTorch BatchNorm.

Returns

ResNet as nn.Module.

Examples

>>> # binary classifier with ResNet-34
>>> from lightly.models import ResNetGenerator
>>> resnet = ResNetGenerator('resnet-34', num_classes=2)
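The number in each name counts the weighted layers: the stem convolution, the convolutions inside the blocks and the final linear layer. The sketch below checks this for the standard variants, assuming the stage layouts (the layers argument of ResNet) from the original ResNet paper. The resnet-9 layout is lightly-specific and deliberately left out, projection shortcuts are not counted, and the names LAYOUTS and depth are hypothetical.

```python
from typing import Dict, List, Tuple

# Assumed stage layouts from the original ResNet paper.
LAYOUTS: Dict[str, Tuple[str, List[int]]] = {
    "resnet-18": ("basic", [2, 2, 2, 2]),
    "resnet-34": ("basic", [3, 4, 6, 3]),
    "resnet-50": ("bottleneck", [3, 4, 6, 3]),
    "resnet-101": ("bottleneck", [3, 4, 23, 3]),
    "resnet-152": ("bottleneck", [3, 8, 36, 3]),
}

CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def depth(name: str) -> int:
    """Stem conv + convolutions in all blocks + final linear layer."""
    if name not in LAYOUTS:
        raise ValueError(f"unknown ResNet name: {name!r}")
    block, layers = LAYOUTS[name]
    return 1 + CONVS_PER_BLOCK[block] * sum(layers) + 1


for name in LAYOUTS:
    print(name, depth(name))
```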

.zoo

Lightly Model Zoo

lightly.models.zoo.checkpoints()

Returns the Lightly model zoo as a list of checkpoints.

Checkpoints:
  • ResNet-9: SimCLR with width = 0.0625 and num_ftrs = 16

  • ResNet-9: SimCLR with width = 0.125 and num_ftrs = 16

  • ResNet-18: SimCLR with width = 1.0 and num_ftrs = 16

  • ResNet-18: SimCLR with width = 1.0 and num_ftrs = 32

  • ResNet-34: SimCLR with width = 1.0 and num_ftrs = 16

  • ResNet-34: SimCLR with width = 1.0 and num_ftrs = 32

Returns

A list of available checkpoints as URLs.
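Choosing a checkpoint usually means matching the architecture and the feature size of the downstream model. The sketch below filters a hypothetical local copy of the checkpoint list documented above. It does not call lightly.models.zoo.checkpoints(), which returns URLs rather than these tuples, and the names Checkpoint, ZOO and find are illustrative only.

```python
from typing import List, NamedTuple


class Checkpoint(NamedTuple):
    model: str
    width: float
    num_ftrs: int


# Hypothetical local mirror of the documented checkpoint list.
ZOO: List[Checkpoint] = [
    Checkpoint("ResNet-9", 0.0625, 16),
    Checkpoint("ResNet-9", 0.125, 16),
    Checkpoint("ResNet-18", 1.0, 16),
    Checkpoint("ResNet-18", 1.0, 32),
    Checkpoint("ResNet-34", 1.0, 16),
    Checkpoint("ResNet-34", 1.0, 32),
]


def find(model: str, num_ftrs: int) -> List[Checkpoint]:
    """Return the documented checkpoints matching a model name and feature size."""
    return [c for c in ZOO if c.model == model and c.num_ftrs == num_ftrs]


print(find("ResNet-18", 32))
print(find("ResNet-9", 32))  # [] (no such checkpoint is listed)
```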

The lightly.models.modules package provides reusable modules.

This package contains reusable modules, such as the NNMemoryBankModule, which can be combined with any lightly model.

.nn_memory_bank

Nearest Neighbour Memory Bank Module

class lightly.models.modules.nn_memory_bank.NNMemoryBankModule(size: int = 65536)

Nearest Neighbour Memory Bank implementation.

This class implements a nearest neighbour memory bank as described in the NNCLR paper [0]. During the forward pass, it returns the nearest neighbour of each input from the memory bank.

[0] NNCLR, 2021, https://arxiv.org/abs/2104.14548

size

Number of keys the memory bank can store. If set to 0, the memory bank is not used.

Examples

>>> from lightly.loss import NTXentLoss
>>> from lightly.models import NNCLR
>>> from lightly.models.modules.nn_memory_bank import NNMemoryBankModule
>>>
>>> model = NNCLR(backbone)
>>> criterion = NTXentLoss(temperature=0.1)
>>>
>>> nn_replacer = NNMemoryBankModule(size=2 ** 16)
>>>
>>> # forward pass
>>> (z0, p0), (z1, p1) = model(x0, x1)
>>> z0 = nn_replacer(z0.detach(), update=False)
>>> z1 = nn_replacer(z1.detach(), update=True)
>>>
>>> loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))

forward(output: torch.Tensor, update: bool = False)

Returns the nearest neighbour of the output tensor from the memory bank.

Parameters
  • output – The torch tensor for which you want the nearest neighbour.

  • update – If True, updates the memory bank by adding output to it.
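The lookup can be illustrated without PyTorch. The class below is a hypothetical, simplified stand-in for NNMemoryBankModule: it stores vectors in a fixed-size FIFO queue, returns for each input the stored vector with the highest cosine similarity, and only afterwards adds the inputs when update=True. The FIFO replacement, the cosine-similarity criterion and the lookup-before-update order follow NNCLR but are assumptions about this module's internals.

```python
import math
from collections import deque
from typing import Deque, List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


class ToyNNMemoryBank:
    """Hypothetical pure-Python analogue of a nearest-neighbour memory bank."""

    def __init__(self, size: int) -> None:
        # FIFO queue: the oldest entries are dropped once `size` is reached
        self.bank: Deque[Vector] = deque(maxlen=size)

    def __call__(self, output: List[Vector], update: bool = False) -> List[Vector]:
        # assumes the bank already holds at least one entry
        neighbours = [max(self.bank, key=lambda m: cosine(z, m)) for z in output]
        if update:
            self.bank.extend(output)  # look up first, then enqueue
        return neighbours


bank = ToyNNMemoryBank(size=3)
bank.bank.extend([[1.0, 0.0], [0.0, 1.0]])
print(bank([[0.9, 0.1], [0.2, 0.8]], update=True))  # [[1.0, 0.0], [0.0, 1.0]]
print(len(bank.bank))  # 3
```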

.heads

Projection and Prediction Heads for Self-supervised Learning

class lightly.models.modules.heads.BYOLPredictionHead(input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256)

Prediction head used for BYOL.

“This MLP consists in a linear layer with output size 4096 followed by batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension 256.” [0]

[0]: BYOL, 2020, https://arxiv.org/abs/2006.07733

class lightly.models.modules.heads.BYOLProjectionHead(input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256)

Projection head used for BYOL.

“This MLP consists in a linear layer with output size 4096 followed by batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension 256.” [0]

[0]: BYOL, 2020, https://arxiv.org/abs/2006.07733

class lightly.models.modules.heads.BarlowTwinsProjectionHead(input_dim: int = 2048, hidden_dim: int = 8192, output_dim: int = 8192)

Projection head used for Barlow Twins.

“The projector network has three linear layers, each with 8192 output units. The first two layers of the projector are followed by a batch normalization layer and rectified linear units.” [0]

[0]: Barlow Twins, 2021, https://arxiv.org/abs/2103.03230

class lightly.models.modules.heads.DINOProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, bottleneck_dim: int = 256, output_dim: int = 65536, batch_norm: bool = False, freeze_last_layer: int = -1, norm_last_layer: bool = True)

Projection head used in DINO.

“The projection head consists of a 3-layer multi-layer perceptron (MLP) with hidden dimension 2048 followed by l2 normalization and a weight normalized fully connected layer with K dimensions, which is similar to the design from SwAV [1].” [0]

input_dim

The input dimension of the head.

hidden_dim

The hidden dimension.

bottleneck_dim

Dimension of the bottleneck in the last layer of the head.

output_dim

The output dimension of the head.

batch_norm

Whether to use batch norm or not. Should be set to False when using a vision transformer backbone.

freeze_last_layer

Number of epochs during which we keep the output layer fixed. Typically doing so during the first epoch helps training. Try increasing this value if the loss does not decrease.

norm_last_layer

Whether or not to weight normalize the last layer of the DINO head. Not normalizing leads to better performance but can make the training unstable.

cancel_last_layer_gradients(current_epoch: int)

Cancel last layer gradients to stabilize the training.

forward(x: torch.Tensor) → torch.Tensor

Computes one forward pass through the head.
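The interplay between freeze_last_layer and cancel_last_layer_gradients reduces to an epoch check. The sketch below assumes that the last layer's gradients are discarded while current_epoch < freeze_last_layer, so the default of -1 never freezes it. The rule is an assumption about the implementation and the helper names are hypothetical. In a training loop, cancel_last_layer_gradients(current_epoch) would presumably be called between loss.backward() and optimizer.step().

```python
from typing import List


def last_layer_frozen(current_epoch: int, freeze_last_layer: int) -> bool:
    """Assumed rule: the output layer stays fixed for the first `freeze_last_layer` epochs."""
    return current_epoch < freeze_last_layer


def frozen_epochs(num_epochs: int, freeze_last_layer: int) -> List[int]:
    """Epochs (0-based) in which the last layer would not be updated."""
    return [e for e in range(num_epochs) if last_layer_frozen(e, freeze_last_layer)]


print(frozen_epochs(5, freeze_last_layer=1))   # [0]
print(frozen_epochs(5, freeze_last_layer=-1))  # []
```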

class lightly.models.modules.heads.MSNProjectionHead(input_dim: int = 768, hidden_dim: int = 2048, output_dim: int = 256)

Projection head for MSN [0].

“We train with a 3-layer projection head with output dimension 256 and batch-normalization at the input and hidden layers.” [0] Code inspired by [1].

input_dim

Input dimension, default value 768 is for a ViT base model.

hidden_dim

Hidden dimension.

output_dim

Output dimension.

class lightly.models.modules.heads.MoCoProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128)

Projection head used for MoCo.

“(…) we replace the fc head in MoCo with a 2-layer MLP head (hidden layer 2048-d, with ReLU)” [0]

[0]: MoCo, 2020, https://arxiv.org/abs/1911.05722

class lightly.models.modules.heads.NNCLRPredictionHead(input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256)

Prediction head used for NNCLR.

“The architecture of the prediction MLP g is 2 fully-connected layers of size [4096,d]. The hidden layer of the prediction MLP is followed by batch-norm and ReLU. The last layer has no batch-norm or activation.” [0]

[0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548

class lightly.models.modules.heads.NNCLRProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 256)

Projection head used for NNCLR.

“The architecture of the projection MLP is 3 fully connected layers of sizes [2048,2048,d] where d is the embedding size used to apply the loss. We use d = 256 in the experiments unless otherwise stated. All fully-connected layers are followed by batch-normalization [36]. All the batch-norm layers except the last layer are followed by ReLU activation.” [0]

[0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548

class lightly.models.modules.heads.ProjectionHead(blocks: List[Tuple[int, int, Optional[torch.nn.modules.module.Module], Optional[torch.nn.modules.module.Module]]])

Base class for all projection and prediction heads.

Parameters

blocks – List of tuples, each denoting one block of the projection head MLP. Each tuple reads (in_features, out_features, batch_norm_layer, non_linearity_layer).

Examples

>>> # the following projection head has two blocks
>>> # the first block uses batch norm and a ReLU non-linearity
>>> # the second block is a simple linear layer
>>> from torch import nn
>>> from lightly.models.modules.heads import ProjectionHead
>>> projection_head = ProjectionHead([
...     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
...     (256, 128, None, None)
... ])

forward(x: torch.Tensor)

Computes one forward pass through the projection head.

Parameters

x – Input of shape bsz x num_ftrs.
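The blocks argument is easiest to read as a flat sequence of layers. The plain-Python sketch below lists that sequence and checks that consecutive blocks chain, i.e. each in_features equals the previous out_features. It assumes that each block expands to a linear layer followed by the optional batch norm and then the optional non-linearity; layers are represented by strings instead of torch modules, and both helper names are hypothetical.

```python
from typing import List, Optional, Tuple

# (in_features, out_features, batch_norm_layer, non_linearity_layer), with
# layers represented by their names instead of torch.nn modules
Block = Tuple[int, int, Optional[str], Optional[str]]


def describe(blocks: List[Block]) -> List[str]:
    """Flatten blocks into the assumed order: linear, batch norm, non-linearity."""
    layers: List[str] = []
    for in_features, out_features, batch_norm, non_linearity in blocks:
        layers.append(f"Linear({in_features}, {out_features})")
        if batch_norm is not None:
            layers.append(batch_norm)
        if non_linearity is not None:
            layers.append(non_linearity)
    return layers


def blocks_chain(blocks: List[Block]) -> bool:
    """True if every block consumes the output size of the block before it."""
    return all(prev[1] == nxt[0] for prev, nxt in zip(blocks, blocks[1:]))


blocks = [(256, 256, "BatchNorm1d(256)", "ReLU()"), (256, 128, None, None)]
print(describe(blocks))
print(blocks_chain(blocks))  # True
```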

class lightly.models.modules.heads.SMoGPredictionHead(input_dim: int = 128, hidden_dim: int = 2048, output_dim: int = 128)

Prediction head used for SMoG.

“The two kinds of head are both a two-layer MLP and their hidden layer is followed by a BatchNorm [28] and an activation function. (…) The output layer of projection head also has BN” [0]

[0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf

class lightly.models.modules.heads.SMoGProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128)

Projection head used for SMoG.

“The two kinds of head are both a two-layer MLP and their hidden layer is followed by a BatchNorm [28] and an activation function. (…) The output layer of projection head also has BN” [0]

[0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf

class lightly.models.modules.heads.SMoGPrototypes(group_features: torch.Tensor, beta: float)

SMoG prototypes module for synchronous momentum grouping.

assign_groups(x: torch.Tensor) → torch.LongTensor

Assigns each representation in x to a group based on cosine similarity.

Parameters

x – Tensor of shape bsz x dim.

Returns

LongTensor of shape bsz indicating group assignments.

forward(x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1) → torch.Tensor

Computes the logits for given model outputs and group features.

Parameters
  • x – Tensor of shape bsz x dim.

  • group_features – Momentum updated group features of shape n_groups x dim.

  • temperature – Temperature parameter for calculating the logits.

Returns

The logits.

get_updated_group_features(x: torch.Tensor) → torch.Tensor

Performs the synchronous momentum update of the group vectors.

Parameters

x – Tensor of shape bsz x dim.

Returns

The updated group features.

set_group_features(x: torch.Tensor) → None

Sets the group features and asserts they don’t require gradient.
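The three operations above can be sketched in plain Python. The sketch assumes that group assignment picks the group with the highest cosine similarity, that the logits are cosine similarities divided by the temperature, and that the synchronous momentum update moves each assigned group towards the mean of its members, g_k ← beta * g_k + (1 - beta) * mean(x assigned to k). These rules follow the SMoG paper's description but are assumptions about this class, and the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def normalize(v: Vector) -> Vector:
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def assign_groups(x: List[Vector], groups: List[Vector]) -> List[int]:
    """Index of the most cosine-similar group for every row of x."""
    gn = [normalize(g) for g in groups]
    return [max(range(len(gn)), key=lambda k: dot(normalize(v), gn[k])) for v in x]


def logits(x: List[Vector], groups: List[Vector], temperature: float = 0.1) -> List[Vector]:
    """Cosine similarities between x and the groups, divided by the temperature."""
    gn = [normalize(g) for g in groups]
    return [[dot(normalize(v), g) / temperature for g in gn] for v in x]


def updated_group_features(x: List[Vector], groups: List[Vector], beta: float) -> List[Vector]:
    """Momentum update of every group that received at least one assignment."""
    assignments = assign_groups(x, groups)
    new_groups = [list(g) for g in groups]
    for k in set(assignments):
        members = [v for v, a in zip(x, assignments) if a == k]
        mean = [sum(col) / len(members) for col in zip(*members)]
        new_groups[k] = [beta * g + (1 - beta) * m for g, m in zip(groups[k], mean)]
    return new_groups


groups = [[1.0, 0.0], [0.0, 1.0]]
x = [[2.0, 0.0], [0.0, 4.0], [0.0, 2.0]]
print(assign_groups(x, groups))                     # [0, 1, 1]
print(updated_group_features(x, groups, beta=0.5))  # [[1.5, 0.0], [0.0, 2.0]]
```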

class lightly.models.modules.heads.SimCLRProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128)

Projection head used for SimCLR.

“We use a MLP with one hidden layer to obtain z_i = g(h_i) = W_2 * σ(W_1 * h_i) where σ is a ReLU non-linearity.” [0]

[0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
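The quoted formula is a two-layer MLP with a ReLU in between. Below is a dependency-free sketch of that computation with tiny, hypothetical hand-picked weights and no biases, matching the formula rather than necessarily the exact layers of SimCLRProjectionHead.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Matrix-vector product: one output entry per row of w."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def relu(v: Vector) -> Vector:
    return [max(0.0, x) for x in v]


def simclr_head(h: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """z = W_2 * relu(W_1 * h), biases omitted as in the quoted formula."""
    return matvec(w2, relu(matvec(w1, h)))


w1 = [[1.0, -1.0], [0.5, 0.5]]  # hypothetical 2 -> 2 hidden layer
w2 = [[1.0, 2.0]]               # hypothetical 2 -> 1 output layer
print(simclr_head([1.0, 3.0], w1, w2))  # [4.0]
```

Here W_1 h = [-2.0, 2.0], the ReLU zeroes the negative entry, and W_2 maps [0.0, 2.0] to [4.0].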

class lightly.models.modules.heads.SimSiamPredictionHead(input_dim: int = 2048, hidden_dim: int = 512, output_dim: int = 2048)

Prediction head used for SimSiam.

“The prediction MLP (h) has BN applied to its hidden fc layers. Its output fc does not have BN (…) or ReLU. This MLP has 2 layers.” [0]

[0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566

class lightly.models.modules.heads.SimSiamProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 2048)

Projection head used for SimSiam.

“The projection MLP (in f) has BN applied to each fully-connected (fc) layer, including its output fc. Its output fc has no ReLU. The hidden fc is 2048-d. This MLP has 3 layers.” [0]

[0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566

class lightly.models.modules.heads.SwaVProjectionHead(input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128)

Projection head used for SwAV.

[0]: SwAV, 2020, https://arxiv.org/abs/2006.09882

class lightly.models.modules.heads.SwaVPrototypes(input_dim: int = 128, n_prototypes: Union[List[int], int] = 3000)

Multihead prototypes used for SwAV.

Each output feature is assigned to a prototype. SwAV solves the swapped prediction problem, where the features of one augmentation are used to predict the assigned prototypes of the other augmentation.

Examples

>>> from torch import nn
>>> from lightly.models.modules.heads import SwaVPrototypes
>>>
>>> # use features with 128 dimensions and 512 prototypes
>>> prototypes = SwaVPrototypes(128, 512)
>>>
>>> # pass batch through backbone and projection head
>>> features = model(x)
>>> features = nn.functional.normalize(features, dim=1, p=2)
>>>
>>> # logits has shape bsz x 512
>>> logits = prototypes(features)

forward(x) → Union[torch.Tensor, List[torch.Tensor]]

Computes the prototype logits for x. Returns a single tensor when n_prototypes is an int and a list of tensors, one per prototype head, when it is a list.

Note

Call the module instance (e.g. prototypes(features)) instead of calling forward() directly: the former runs any registered hooks, while the latter silently ignores them.

normalize()

Normalizes the prototypes so that they are on the unit sphere.
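The swapped prediction problem can be sketched as a cross-entropy between the codes of one view and the softmax of the other view's prototype scores. This sketch assumes the codes q are already given (SwAV computes them with the Sinkhorn-Knopp algorithm, omitted here) and uses a hypothetical temperature of 0.1. It illustrates the objective and is not the lightly loss implementation; the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def log_softmax(scores: Vector, temperature: float) -> Vector:
    scaled = [s / temperature for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_sum for s in scaled]


def cross_entropy(q: Vector, scores: Vector, temperature: float) -> float:
    """-sum_k q_k * log softmax(scores / temperature)_k"""
    return -sum(qk * lp for qk, lp in zip(q, log_softmax(scores, temperature)))


def swapped_prediction_loss(
    scores_a: Vector, scores_b: Vector, q_a: Vector, q_b: Vector, temperature: float = 0.1
) -> float:
    """Predict the codes of view b from view a and vice versa, then average."""
    return 0.5 * (
        cross_entropy(q_b, scores_a, temperature)
        + cross_entropy(q_a, scores_b, temperature)
    )


q = [1.0, 0.0, 0.0]  # both views assigned to prototype 0
agree = swapped_prediction_loss([0.9, 0.1, 0.0], [0.8, 0.2, 0.1], q, q)
disagree = swapped_prediction_loss([0.0, 0.9, 0.1], [0.1, 0.8, 0.2], q, q)
print(agree < disagree)  # True
```

Because features and prototypes are both on the unit sphere, the scores are cosine similarities, which is why the temperature is needed to sharpen the softmax.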