Custom Models

Note

Training custom models from the command line or with Docker is not yet supported.

LightlyTrain supports training custom models. This requires writing a small wrapper around your model to implement the necessary methods. The wrapper must be a subclass of torch.nn.Module and implement the following methods:

  • forward_features(self, x: Tensor) -> Tensor

    Forward pass of the model that extracts features without pooling them.

  • forward_pool(self, x: Tensor) -> Tensor

    Forward pass of the pooling layer that pools the features extracted by forward_features.

  • num_features(self) -> int

    The number of output features (channels) of the tensors returned by forward_features and forward_pool.

The methods are described in more detail in the template below:

import lightly_train
from torch import Tensor
from torch.nn import Module


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model     # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Implement the feature extraction forward pass here. This method takes images
        # as input and extracts features from them. In most cases this method should
        # call your model's backbone or encoder. The method should not pool the final
        # features and should not pass them through any classification/detection/etc.
        # heads.
        # The input is a batch of images with shape (B, 3, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out).
        return x

    def forward_pool(self, x: Tensor) -> Tensor:
        # Implement the pooling layer forward pass here. This method must take the
        # output of the forward_features method as input and pool the features.
        # The input is a batch of features with shape (B, num_features, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out),
        # where H_out and W_out are usually 1.
        return x

    def num_features(self) -> int:
        # Implement the number of output features here. This method must return the
        # number of output features (channels) of the forward_features and
        # forward_pool methods.
        return 2048

model = ... # Instantiate the model you want to train
wrapped_model = MyModelWrapper(model) # Wrap the model

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)

The wrapped model will be called as follows inside LightlyTrain:

embedding_layer = EmbeddingLayer(input_dim=wrapped_model.num_features())

images = load_batch()
x = transform(images)   # Augment and convert images to tensor
x = wrapped_model.forward_features(x)
x = wrapped_model.forward_pool(x)
x = embedding_layer(x)
embeddings = x.flatten(start_dim=1)

Some SSL methods (e.g. DenseCL) do not call the forward_pool method and only use the unpooled features. In this case, the embedding layer is applied directly to the output of forward_features.
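
For these methods, the call flow skips the pooling step. A minimal sketch, reusing the hypothetical helpers from the pseudocode above:

embedding_layer = EmbeddingLayer(input_dim=wrapped_model.num_features())

images = load_batch()
x = transform(images)   # Augment and convert images to tensor
x = wrapped_model.forward_features(x)   # forward_pool is not called
x = embedding_layer(x)  # Applied directly to the unpooled features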

Example

The following example demonstrates how to write a wrapper for a torchvision ResNet-18 model.

import lightly_train
from torch import Tensor
from torch.nn import Module
from torchvision.models import resnet18


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model     # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Torchvision's ResNet has no method for extracting features only, so we have
        # to call the intermediate layers of the model individually. Note that we skip
        # the final average pooling and fully connected classification layer.
        x = self._model.conv1(x)
        x = self._model.bn1(x)
        x = self._model.relu(x)
        x = self._model.maxpool(x)

        x = self._model.layer1(x)
        x = self._model.layer2(x)
        x = self._model.layer3(x)
        x = self._model.layer4(x)
        return x
    
    def forward_pool(self, x: Tensor) -> Tensor:
        # Here we call the average pooling layer of the model to pool the features.
        x = self._model.avgpool(x)
        return x

    def num_features(self) -> int:
        # ResNet-18 has 512 output features after the last convolutional layer.
        return 512


model = resnet18()
wrapped_model = MyModelWrapper(model)

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)
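
Before launching a full training run, it can be useful to verify that the wrapper returns the expected shapes. A quick sanity check for the ResNet-18 wrapper above (the output shapes assume 224x224 input images):

import torch

wrapped_model = MyModelWrapper(resnet18())
x = torch.randn(2, 3, 224, 224)                  # Batch of two random images
features = wrapped_model.forward_features(x)
print(features.shape)                            # torch.Size([2, 512, 7, 7])
pooled = wrapped_model.forward_pool(features)
print(pooled.shape)                              # torch.Size([2, 512, 1, 1])
assert features.shape[1] == wrapped_model.num_features()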