Custom Models
Note
Training custom models from the command line or with Docker is not yet supported.
LightlyTrain supports training custom models. This requires writing a small wrapper
around your model that implements the necessary methods. The wrapper must be a subclass
of torch.nn.Module and implement the following methods:
forward_features(self, x: Tensor) -> Tensor
Forward pass of the model that extracts features without pooling them.
forward_pool(self, x: Tensor) -> Tensor
Forward pass of the pooling layer that pools the features extracted by forward_features.
num_features(self) -> int
Number of output features (channels) of the features returned by forward_features and forward_pool.
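Together, the three methods form a shape contract: forward_features returns a 4D tensor whose channel dimension equals num_features(), and forward_pool keeps that channel dimension while shrinking the spatial size. A minimal sketch of checking this on a random batch before training follows; `check_wrapper` and `DemoWrapper` are hypothetical names used only for illustration and are not part of LightlyTrain.

```python
import torch
from torch import Tensor, nn


def check_wrapper(wrapped: nn.Module, image_size: int = 64, batch_size: int = 2) -> None:
    """Hypothetical helper: checks the three-method contract on a random batch."""
    x = torch.randn(batch_size, 3, image_size, image_size)
    with torch.no_grad():
        features = wrapped.forward_features(x)
        pooled = wrapped.forward_pool(features)
    num = wrapped.num_features()
    # forward_features must return unpooled features of shape (B, num_features, H, W).
    assert features.ndim == 4, f"expected 4D features, got {tuple(features.shape)}"
    assert features.shape[:2] == (batch_size, num), "channels != num_features()"
    # forward_pool keeps the channel dimension; the spatial size is usually 1x1.
    assert pooled.ndim == 4 and pooled.shape[:2] == (batch_size, num)


class DemoWrapper(nn.Module):
    """Toy wrapper around a two-layer CNN, used only to exercise check_wrapper."""

    def __init__(self) -> None:
        super().__init__()
        self._model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
        )
        self._pool = nn.AdaptiveAvgPool2d(1)

    def forward_features(self, x: Tensor) -> Tensor:
        return self._model(x)

    def forward_pool(self, x: Tensor) -> Tensor:
        return self._pool(x)

    def num_features(self) -> int:
        return 32  # output channels of the last convolution


check_wrapper(DemoWrapper())  # raises AssertionError if the contract is violated
```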
The methods are described in more detail in the template below:
import lightly_train
from torch import Tensor
from torch.nn import Module


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model  # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Implement the feature extraction forward pass here. This method takes images
        # as input and extracts features from them. In most cases this method should
        # call your model's backbone or encoder. The method should not pool the final
        # features and should not pass them through any classification/detection/etc.
        # heads.
        # The input is a batch of images with shape (B, 3, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out).
        return x

    def forward_pool(self, x: Tensor) -> Tensor:
        # Implement the pooling layer forward pass here. This method must take the
        # output of the forward_features method as input and pool the features.
        # The input is a batch of features with shape (B, num_features, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out),
        # where H_out and W_out are usually 1.
        return x

    def num_features(self) -> int:
        # Return the number of output features here.
        # This method must return the number of output features of the forward_features
        # and forward_pool methods.
        return 2048


model = ...  # Instantiate the model you want to train
wrapped_model = MyModelWrapper(model)  # Wrap the model

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)
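The template asks forward_features to call your model's backbone and to skip any heads. As one illustration, suppose your model (here a hypothetical `MyClassifier`, defined only for this sketch) exposes `.backbone` and `.head` attributes. The wrapper below calls only `backbone` and adds its own `torch.nn.AdaptiveAvgPool2d(1)` for forward_pool, because this model's global pooling is part of its head and is therefore not reused.

```python
import torch
from torch import Tensor, nn


class MyClassifier(nn.Module):
    """Hypothetical model: a convolutional backbone followed by a classification head."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(128, num_classes)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.head(self.backbone(x))


class MyClassifierWrapper(nn.Module):
    def __init__(self, model: MyClassifier) -> None:
        super().__init__()
        self._model = model
        # Separate pooling layer, because the model's own pooling sits inside its head.
        self._pool = nn.AdaptiveAvgPool2d(1)

    def forward_features(self, x: Tensor) -> Tensor:
        # Backbone only; the classification head is never called.
        return self._model.backbone(x)

    def forward_pool(self, x: Tensor) -> Tensor:
        # (B, 128, H, W) -> (B, 128, 1, 1)
        return self._pool(x)

    def num_features(self) -> int:
        # Channels produced by the last convolution of the backbone.
        return 128


wrapped_model = MyClassifierWrapper(MyClassifier())
features = wrapped_model.forward_features(torch.randn(2, 3, 32, 32))  # (2, 128, 8, 8)
```

The resulting wrapped_model is passed to lightly_train.train in the same way as in the template.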
The wrapped model will be called as follows inside LightlyTrain:
embedding_layer = EmbeddingLayer(input_dim=wrapped_model.num_features())
images = load_batch()
x = transform(images) # Augment and convert images to tensor
x = wrapped_model.forward_features(x)
x = wrapped_model.forward_pool(x)
x = embedding_layer(x)
embeddings = x.flatten(start_dim=1)
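The call sequence can be reproduced outside LightlyTrain to follow the tensor shapes step by step. This is an illustrative sketch only: `EmbeddingLayer` here is a hypothetical stand-in (assumed to be a 1x1 convolution projecting to 64 dimensions, which may differ from LightlyTrain's internal layer), `TinyWrapper` is a toy model, and a random tensor replaces the loaded and augmented batch.

```python
import torch
from torch import Tensor, nn


class TinyWrapper(nn.Module):
    """Toy wrapper: one strided convolution plus global average pooling."""

    def __init__(self) -> None:
        super().__init__()
        self._conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self._pool = nn.AdaptiveAvgPool2d(1)

    def forward_features(self, x: Tensor) -> Tensor:
        return self._conv(x)

    def forward_pool(self, x: Tensor) -> Tensor:
        return self._pool(x)

    def num_features(self) -> int:
        return 32


class EmbeddingLayer(nn.Module):
    """Hypothetical stand-in for LightlyTrain's embedding layer (assumed: 1x1 conv)."""

    def __init__(self, input_dim: int, output_dim: int = 64) -> None:
        super().__init__()
        self.proj = nn.Conv2d(input_dim, output_dim, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(x)


wrapped_model = TinyWrapper()
embedding_layer = EmbeddingLayer(input_dim=wrapped_model.num_features())
x = torch.randn(4, 3, 64, 64)          # stands in for transform(load_batch())
x = wrapped_model.forward_features(x)  # (4, 32, 32, 32)
x = wrapped_model.forward_pool(x)      # (4, 32, 1, 1)
x = embedding_layer(x)                 # (4, 64, 1, 1)
embeddings = x.flatten(start_dim=1)    # (4, 64)
```

The sketch also shows why num_features() must match the actual channel count: it sizes the layer that consumes the features, so a mismatch fails at the first embedding call.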
Some SSL methods (e.g. DenseCL) do not call the forward_pool method and only use the
unpooled features. In this case, the embedding layer is applied directly to the output
of forward_features.
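When forward_pool is skipped, the embedding layer receives the full feature map, so the result keeps one embedding per spatial location. A minimal sketch, again with a toy wrapper and a hypothetical embedding layer assumed to be a 1x1 convolution (a choice that works on any spatial size); the final reshape to (B, H*W, D) is just one common layout for per-location embeddings and is not taken from LightlyTrain.

```python
import torch
from torch import Tensor, nn


class TinyWrapper(nn.Module):
    """Toy wrapper; forward_pool is defined but unused on the dense path."""

    def __init__(self) -> None:
        super().__init__()
        self._conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self._pool = nn.AdaptiveAvgPool2d(1)

    def forward_features(self, x: Tensor) -> Tensor:
        return self._conv(x)

    def forward_pool(self, x: Tensor) -> Tensor:
        return self._pool(x)

    def num_features(self) -> int:
        return 32


wrapped_model = TinyWrapper()
# Hypothetical embedding layer: a 1x1 convolution applied at every spatial location.
embedding_layer = nn.Conv2d(wrapped_model.num_features(), 64, kernel_size=1)

x = torch.randn(4, 3, 64, 64)
features = wrapped_model.forward_features(x)  # (4, 32, 32, 32), not pooled
dense = embedding_layer(features)             # (4, 64, 32, 32)
per_location = dense.flatten(start_dim=2).transpose(1, 2)  # (4, 32 * 32, 64)
```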
Example
The following example demonstrates how to write a wrapper for a torchvision ResNet-18 model.
import lightly_train
from torch import Tensor
from torch.nn import Module
from torchvision.models import resnet18


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model  # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Torchvision ResNet has no method for only extracting features. We have to
        # call the intermediate layers of the model individually.
        # Note that we skip the final average pooling and fully connected classification
        # layers.
        x = self._model.conv1(x)
        x = self._model.bn1(x)
        x = self._model.relu(x)
        x = self._model.maxpool(x)
        x = self._model.layer1(x)
        x = self._model.layer2(x)
        x = self._model.layer3(x)
        x = self._model.layer4(x)
        return x

    def forward_pool(self, x: Tensor) -> Tensor:
        # Here we call the average pooling layer of the model to pool the features.
        x = self._model.avgpool(x)
        return x

    def num_features(self) -> int:
        # ResNet-18 has 512 output features after the last convolutional layer.
        return 512


model = resnet18()
wrapped_model = MyModelWrapper(model)

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)