(custom-models)=

# Custom Models

```{note}
Training custom models from the command line or with Docker is not yet supported.
```

Lightly**Train** supports training custom models. This requires writing a small wrapper around your model that implements the necessary methods. The wrapper must be a subclass of `torch.nn.Module` and implement the following methods:

- `forward_features(self, x: Tensor) -> Tensor`

  Forward pass of the model that extracts features without pooling them.

- `forward_pool(self, x: Tensor) -> Tensor`

  Forward pass of the pooling layer that pools the features extracted by `forward_features`.

- `num_features(self) -> int`

  Number of output features (channels) of the features returned by `forward_features` and `forward_pool`.

The methods are described in more detail in the template below:

```python
import lightly_train
from torch import Tensor
from torch.nn import Module


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model  # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Implement the feature extraction forward pass here. This method takes images
        # as input and extracts features from them. In most cases this method should
        # call your model's backbone or encoder. The method should not pool the final
        # features and should not pass them through any classification/detection/etc.
        # heads.
        # The input is a batch of images with shape (B, 3, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out).
        return x

    def forward_pool(self, x: Tensor) -> Tensor:
        # Implement the pooling layer forward pass here. This method must take the
        # output of the forward_features method as input and pool the features.
        # The input is a batch of features with shape (B, num_features, H_in, W_in).
        # The output is a batch of features with shape (B, num_features, H_out, W_out),
        # where H_out and W_out are usually 1.
        return x

    def num_features(self) -> int:
        # Implement the number of output features here.
        # This method must return the number of output features of the forward_features
        # and forward_pool methods.
        return 2048


model = ...  # Instantiate the model you want to train
wrapped_model = MyModelWrapper(model)  # Wrap the model

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)
```

The wrapped model is called as follows inside Lightly**Train**:

```python
embedding_layer = EmbeddingLayer(input_dim=wrapped_model.num_features())

images = load_batch()
x = transform(images)  # Augment and convert images to tensor
x = wrapped_model.forward_features(x)
x = wrapped_model.forward_pool(x)
x = embedding_layer(x)
embeddings = x.flatten(start_dim=1)
```

Some [SSL methods](#methods) (e.g. DenseCL) do not call the `forward_pool` method and only use the unpooled features. In this case, the embedding layer is applied directly to the output of `forward_features`.

## Example

The following example demonstrates how to write a wrapper for a torchvision ResNet-18 model.

```python
import lightly_train
from torch import Tensor
from torch.nn import Module
from torchvision.models import resnet18


class MyModelWrapper(Module):
    def __init__(self, model: Module):
        super().__init__()
        self._model = model  # Pass your model here

    def forward_features(self, x: Tensor) -> Tensor:
        # Torchvision ResNet has no method for only extracting features. We have to
        # call the intermediate layers of the model individually.
        # Note that we skip the final average pooling and fully connected
        # classification layer.
        x = self._model.conv1(x)
        x = self._model.bn1(x)
        x = self._model.relu(x)
        x = self._model.maxpool(x)
        x = self._model.layer1(x)
        x = self._model.layer2(x)
        x = self._model.layer3(x)
        x = self._model.layer4(x)
        return x

    def forward_pool(self, x: Tensor) -> Tensor:
        # Here we call the average pooling layer of the model to pool the features.
        x = self._model.avgpool(x)
        return x

    def num_features(self) -> int:
        # ResNet-18 has 512 output features after the last convolutional layer.
        return 512


model = resnet18()
wrapped_model = MyModelWrapper(model)

lightly_train.train(
    out="out/my_experiment",
    data="my_data_dir",
    model=wrapped_model,
    method="dino",
)
```
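Before starting a training run, it can be useful to sanity-check the wrapper by passing a dummy batch through it and verifying that the output channels match what `num_features` reports. The snippet below is a minimal sketch of such a check using plain PyTorch; the batch size and image resolution are arbitrary and not requirements of Lightly**Train**.

```python
import torch
from torchvision.models import resnet18

model = resnet18()
wrapped_model = MyModelWrapper(model)

# Dummy batch of two 224x224 RGB images.
x = torch.rand(2, 3, 224, 224)

features = wrapped_model.forward_features(x)   # (B, num_features, H_out, W_out)
pooled = wrapped_model.forward_pool(features)  # (B, num_features, 1, 1)

assert features.shape[1] == wrapped_model.num_features()
assert pooled.shape[1] == wrapped_model.num_features()
print(features.shape, pooled.shape)
# torch.Size([2, 512, 7, 7]) torch.Size([2, 512, 1, 1])
```

If the assertions pass, the wrapped model can be passed to `lightly_train.train` as shown above.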