Tutorial 8: Using timm Models as Backbones
You can use LightlySSL to pre-train any timm model with self-supervised learning, because most methods share a similar workflow. Every method has a base model (the backbone), which can be any standard architecture such as ResNet or MobileNet.
In this tutorial, we will learn how to use a model architecture from the timm library as a backbone in a self-supervised learning workflow.
Import the Python frameworks we need for this tutorial. Make sure you have the necessary packages installed.
pip install lightly"[timm]"
import timm
import torch
import torch.nn as nn
timm ships with more than 700 pre-trained models designed to be flexible and easy to use. Models are created with the create_model() function. By default it returns randomly initialized weights; passing pretrained=True loads the pre-trained ones instead. For example, the following snippet creates an EfficientNet-B0 model:
backbone = timm.create_model(model_name="efficientnet_b0")
Using a timm Model as a Backbone
We can now use this model as a backbone for training. Let’s see how we can create a torch module for the SimCLR method.
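from lightly.models.modules.heads import SimCLRProjectionHead


class SimCLR(torch.nn.Module):
    def __init__(self, backbone, feature_dim, out_dim):
        # nn.Module must be initialized before submodules can be assigned
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(feature_dim, feature_dim, out_dim)

    def forward(self, x):
        # Unpooled feature map from the timm backbone, e.g. (B, 1280, 7, 7) for EfficientNet-B0
        features = self.backbone.forward_features(x)
        # Global pooling (average by default in timm), flattened to (B, feature_dim)
        h = self.backbone.global_pool(features).flatten(start_dim=1)
        # Project the representation into the space used by the contrastive loss
        z = self.projection_head(h)
        return z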
simclr = SimCLR(backbone, feature_dim=1280, out_dim=128)
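# Sanity check: pass two random batches of 224x224 RGB images through the model
input_a = torch.randn((2, 3, 224, 224))
input_b = torch.randn((2, 3, 224, 224))
out_a = simclr(input_a)
out_b = simclr(input_b)
print(out_a.shape)  # torch.Size([2, 128])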
Next Steps
For an in-depth tutorial on fine-tuning a model pre-trained with SimCLR, refer to Tutorial 7: Finetuning Lightly Checkpoints. Interested in pre-training your own self-supervised models? Check out our other tutorials: