Custom Models

Note

Training custom models from the command line or with Docker is not yet supported.

LightlyTrain supports training custom models. This requires writing a small wrapper around your model to implement the necessary methods. The wrapper must be a subclass of torch.nn.Module and implement the following methods:

  • forward_features(self, x: Tensor) -> Dict[str, Any]

    Forward pass of the model that extracts features without pooling them.

  • forward_pool(self, x: Dict[str, Any]) -> Dict[str, Any]

    Forward pass of the pooling layer that pools the features extracted by forward_features.

  • feature_dim(self) -> int

    The number of feature channels in the tensors returned by forward_features and forward_pool.

The methods are described in more detail in the template below:

from typing import Any, Dict

from torch import Tensor
from torch.nn import Module

import lightly_train


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model     # Pass your model here

    def forward_features(self, x: Tensor) -> Dict[str, Any]:
        """Forward pass to extract features from images.

        Implement the feature extraction forward pass here. This method takes images
        as input and extracts features from them. In most cases this method should
        call your model's backbone or encoder. The method should not pool the final
        features and should not pass them through any classification, detection, or
        other heads.

        Args:
            x: Batch of images with shape (B, 3, H_in, W_in).
        
        Returns:
            Dict with a "features" entry containing the features tensor with shape
            (B, feature_dim, H_out, W_out). Add any other entries to the dict if they
            are needed in the forward_pool method. For example, for transformer models
            you might want to return the class token as well.
        """
        features = ...
        return {"features": features}

    def forward_pool(self, x: Dict[str, Any]) -> Dict[str, Any]:
        """Forward pass to pool features extracted by forward_features.

        Implement the pooling layer forward pass here. This method must take the
        output of the forward_features method as input and pool the features.

        Args:
            x: Dict with a "features" entry containing the features tensor with
                shape (B, feature_dim, H_in, W_in).

        Returns:
            Dict with a "pooled_features" entry containing the pooled features tensor
            with shape (B, feature_dim, H_out, W_out). H_out and W_out are usually 1.
        """
        pooled_features = ...
        return {"pooled_features": pooled_features}

    def feature_dim(self) -> int:
        """Return the dimension of output features.

        This method must return the dimension of output features of the forward_features
        and forward_pool methods.
        """
        return 2048


if __name__ == "__main__":
    model = ...  # Instantiate the model you want to train
    wrapped_model = MyModelWrapper(model) # Wrap the model

    lightly_train.train(
        out="out/my_experiment",
        data="my_data_dir",
        model=wrapped_model,
    )

The wrapped model will be called as follows inside LightlyTrain:

embedding_layer = EmbeddingLayer(input_dim=wrapped_model.feature_dim())

images = load_batch()
x = transform(images)   # Augment and convert images to tensor
x = wrapped_model.forward_features(x)
x = wrapped_model.forward_pool(x)
x = embedding_layer(x)
embeddings = x.flatten(start_dim=1)

Some SSL methods do not call the forward_pool method and only use the unpooled features. In this case, the embedding layer is applied directly to the output of forward_features.
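The call sequence above can be sketched end to end with a toy wrapper. This is a minimal, self-contained shape check, not LightlyTrain's actual internals: the real EmbeddingLayer is internal to the library, so a 1x1 convolution serves as a hypothetical stand-in here, and the ToyWrapper backbone is a single convolution chosen only to make the tensor shapes concrete.

```python
from typing import Any, Dict

import torch
from torch import Tensor
from torch.nn import AdaptiveAvgPool2d, Conv2d, Module


class ToyWrapper(Module):
    """Minimal wrapper with a single-conv backbone, for shape checking only."""

    def __init__(self) -> None:
        super().__init__()
        self._backbone = Conv2d(3, 8, kernel_size=3, padding=1)
        self._pool = AdaptiveAvgPool2d((1, 1))

    def forward_features(self, x: Tensor) -> Dict[str, Any]:
        return {"features": self._backbone(x)}

    def forward_pool(self, x: Dict[str, Any]) -> Dict[str, Any]:
        return {"pooled_features": self._pool(x["features"])}

    def feature_dim(self) -> int:
        return 8


wrapper = ToyWrapper()
# Hypothetical stand-in for LightlyTrain's internal embedding layer.
embedding_layer = Conv2d(wrapper.feature_dim(), 16, kernel_size=1)

x = torch.randn(4, 3, 32, 32)            # Batch of 4 RGB images
features = wrapper.forward_features(x)   # "features": (4, 8, 32, 32)
pooled = wrapper.forward_pool(features)  # "pooled_features": (4, 8, 1, 1)
embeddings = embedding_layer(pooled["pooled_features"]).flatten(start_dim=1)
print(embeddings.shape)                  # torch.Size([4, 16])
```

For a method that skips pooling, the stand-in embedding layer would instead be applied to `features["features"]` directly, producing one embedding per spatial location rather than one per image.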

Example

The following example demonstrates how to write a wrapper for a torchvision ResNet-18 model.

from typing import Any, Dict

from torch import Tensor
from torch.nn import Module
from torchvision.models import resnet18

import lightly_train


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model     # Pass your model here

    def forward_features(self, x: Tensor) -> Dict[str, Any]:
        # Torchvision ResNet has no method for only extracting features. We have to
        # call the intermediate layers of the model individually.
        # Note that we skip the final average pooling and fully connected classification
        # layer.
        x = self._model.conv1(x)
        x = self._model.bn1(x)
        x = self._model.relu(x)
        x = self._model.maxpool(x)

        x = self._model.layer1(x)
        x = self._model.layer2(x)
        x = self._model.layer3(x)
        x = self._model.layer4(x)
        return {"features": x}
    
    def forward_pool(self, x: Dict[str, Any]) -> Dict[str, Any]:
        # Here we call the average pooling layer of the model to pool the features.
        x = self._model.avgpool(x["features"])
        return {"pooled_features": x}

    def feature_dim(self) -> int:
        # ResNet-18 has 512 output features after the last convolutional layer.
        return 512


if __name__ == "__main__":
    model = resnet18()
    wrapped_model = MyModelWrapper(model)

    lightly_train.train(
        out="out/my_experiment",
        data="my_data_dir",
        model=wrapped_model,
    )