Train a Self-Supervised Model

Sometimes it may be beneficial to fine-tune a self-supervised learning (SSL) model on your data before embedding the images. This is often the case when the data comes from a specific domain, for example, medical or agricultural images. The extra training improves embedding quality because the model can adapt to the specific domain.

For instance, in an agricultural use case, a general-purpose embedding model might only recognize something as "leaf-like" in the embedding. It would struggle to distinguish between different types of leaves, as that would require storing additional information in the embedding vector. An SSL model trained on leaves, on the other hand, can learn to encode the subtle differences between various types of leaves. This could enable it to differentiate between plant species or distinguish between healthy and unhealthy leaves.

Embeddings From the Target Model vs. an SSL Model

The target model is the final machine learning model used to solve the actual task, and it is trained on the data selected by the Lightly Worker. A target model, such as an object detection model, also produces vectors that can be used as embeddings, e.g. the activations of its backbone or of its last layer. These embeddings aim to capture the characteristics of an object that are important for its classification: they encode the features that help determine which class an object belongs to, while leaving out any classification-irrelevant characteristics. One drawback of using an object detector as an embedding model is that the resulting embeddings include localization information about the object, which is not helpful for achieving visual diversity.

A good example of this is the way a car is classified by a target model. The color or lighting conditions of the car, for example, contribute very little to distinguishing it from a truck, so these properties might not be encoded in the target model's embeddings. SSL models, on the other hand, can take all properties into account, including lighting, color, and perspective. This is because SSL models also learn to differentiate between images within the same class, not only between classes. For instance, if the same car is captured in multiple images with varying perspectives and lighting conditions, an SSL model will produce different embeddings for it. These embeddings can therefore be used to sample different representations of the same car, which is very useful if, for example, detection fails under certain lighting conditions.

Furthermore, SSL has the advantage that it is not influenced by the choice of classes or by label errors, making it more general-purpose.

Train an SSL Model with the Lightly Worker

The code below trains a self-supervised model on the input images before embedding the images and running the selection algorithm:

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a run with training enabled.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
)

The model is trained for 100 epochs by default. You can adjust the training settings via the lightly_config argument. The settings you will most commonly want to change are max_epochs and num_workers, which set the maximum number of training epochs and the number of parallel data-loading processes, respectively:

client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True,
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
    lightly_config={
        "loader": {
            "num_workers": -1,  # -1 uses one worker per available CPU core
        },
        "trainer": {
            "max_epochs": 100,  # Maximum number of training epochs
        },
    }
)

You can find a complete list of settings that can be adjusted here: Configuration Options.

Train on Object Crops

For datasets containing objects, it is often beneficial to train the model only on the object crops instead of the full images. This results in better embedding and selection quality. See the docs on Crop Selection on how to train a model on object crops and use it for selection.
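
As a rough sketch of what this might look like (the exact configuration is described in the Crop Selection docs; the object_level option and the task name below are assumptions used as placeholders):

# Sketch: train on object crops instead of full images.
# "my_object_detection_task" is a placeholder task name; see the
# Crop Selection docs for the exact configuration for your setup.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True,
        "object_level": {
            "task_name": "my_object_detection_task",
        },
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}}
        ],
    },
)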

Checkpoints

Checkpoints from your training process will be stored in the Lightly Platform as artifacts. The last checkpoint from a run can be downloaded using the Lightly Python client:

runs = client.get_compute_worker_runs()
run = runs[-1]  # Get last run on the dataset.

# Alternatively you can get the run from a scheduled run id:
# run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)

client.download_compute_worker_run_checkpoint(
    run=run, output_path="artifacts/checkpoint.ckpt"
)

Use Checkpoints from Previous Runs

Checkpoints from previous runs can be reused when scheduling a new run, allowing you to skip the expensive training step. If you schedule the job on the same datapool that was used for training, the checkpoint is used automatically. You can also reuse checkpoints from runs on other datasets:

🚧

Caveat: S3 Delegated Access

If you are using S3 delegated access, the read-URL can expire before the Lightly Worker is able to download the checkpoint. If this happens, we recommend generating a non-expiring public read-URL for the checkpoint or increasing the Maximum session duration.

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Get the read-URL of a previous run on another dataset.
## Option 1: Get the run via the dataset ID
runs = client.get_compute_worker_runs(dataset_id="DATASET_ID_WITH_THE_CHECKPOINT_TO_BE_REUSED")
run = runs[-1]  # Get the last run of the dataset.
## Option 2: Get the run via the run ID:
# run = client.get_compute_worker_run(run_id="MY_RUN_ID")

checkpoint_url = client.get_compute_worker_run_checkpoint_url(run=run)

# Schedule a new run, reusing the checkpoint from another run.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": checkpoint_url,
        ...
    },
    selection_config={...},
)

Use a Checkpoint from a Model Trained Outside the Lightly Worker

If you require a more custom training schedule, you can also train the embedding model outside of the Lightly Worker using your own PyTorch training code. The following code generates a model that is compatible with the Lightly Worker:

from lightly.models import ResNetGenerator, batchnorm
from lightly.models.modules import SimCLRProjectionHead
from torch import nn


class SimCLR(nn.Module):
    def __init__(self, backbone, num_ftrs, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, out_dim)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


def build_model(name="resnet-18", out_dim=128, num_ftrs=32, width=1):
    """Returns a SimCLR model for training."""
    resnet = ResNetGenerator(name=name, width=width)
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(
        backbone=backbone,
        num_ftrs=num_ftrs,
        out_dim=out_dim,
    )
    return model


model = build_model()

After training, you have to save the model checkpoint to make it loadable by the Lightly Worker:

import torch

torch.save({"state_dict": model.state_dict(prefix="model.")}, "checkpoint.ckpt")
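
The prefix="model." argument namespaces all parameter keys under model., which matches the key layout of checkpoints produced by the Lightly Worker (the load_checkpoint function shown further below strips the model.backbone. prefix again when loading).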

Finally, you have to make the checkpoint available to the worker by providing a read-URL to it. How you do this is up to you; for example, you can upload your checkpoint to a cloud bucket and generate a read-URL for it there. Since the Lightly Worker downloads the checkpoint as part of the embedding step, make sure to set the expiration of the URL accordingly.
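
For example, with an S3 bucket you could upload the checkpoint and create a presigned read-URL using boto3 (a sketch; the bucket name, key, and expiry below are placeholder assumptions). The resulting URL can then be passed as the checkpoint in the worker_config, as shown below:

import boto3

s3 = boto3.client("s3")

# Upload the checkpoint to a bucket of your choice (placeholder names).
s3.upload_file("checkpoint.ckpt", "my-bucket", "checkpoints/checkpoint.ckpt")

# Generate a presigned read-URL. Choose an expiry long enough for the
# Lightly Worker to finish the embedding step (here: 7 days).
checkpoint_read_url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": "my-bucket", "Key": "checkpoints/checkpoint.ckpt"},
    ExpiresIn=7 * 24 * 60 * 60,
)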

# Schedule a run using a checkpoint provided via read-URL
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "https://my-checkpoint_read-url",
        ...
    },
    selection_config={...},
)

🚧

Changing Model Parameters

If you change model parameters, such as name, out_dim, num_ftrs, or width, you have to specify them when scheduling a new run. This is done by setting the model parameters in the Lightly configuration:

client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "checkpoint.ckpt",
        "enable_training": False,
    },
    lightly_config={
        "model": {
            "name": "resnet-18",
            "out_dim": 128,
            "num_ftrs": 32,
            "width": 1,
        },
    },
    ...
)

See Configuration Options for more details.

Load Model from Checkpoint

You might also want to use the checkpoint independently of the Lightly Worker, for example, to train a classifier on your dataset. Using a pretrained checkpoint often results in models with higher accuracy than training a new model from scratch. The code below demonstrates how the checkpoint can be loaded as a PyTorch model in your Python application:

from collections import OrderedDict
import torch
import lightly


def load_checkpoint(
    checkpoint_path, model_name="resnet-18", model_width=1, map_location="cpu"
):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # Drop the projection head and the last backbone layer (backbone.7),
            # which are only used during self-supervised training.
            continue
        state_dict[key.replace("model.backbone.", "")] = value

    resnet = lightly.models.ResNetGenerator(name=model_name, width=model_width)
    model = torch.nn.Sequential(
        lightly.models.batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(1),
    )
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        raise RuntimeError(
            f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! "
            f"Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the "
            f"configuration you used to run Lightly. The configuration from a Lightly worker run can be found in output_dir/config/config.yaml"
        )
    return model


# Load the model
model = load_checkpoint(checkpoint_path="artifacts/checkpoint.ckpt")


# Example usage
image_batch = torch.rand(16, 3, 224, 224)
out = model(image_batch)
print(out.shape)  # prints: torch.Size([16, 512])


# Creating a classifier from the pre-trained model
num_classes = 10
classifier = torch.nn.Sequential(
    model, torch.nn.Linear(512, num_classes)  # use 2048 instead of 512 for resnet-50
)

out = classifier(image_batch)
print(out.shape)  # prints: torch.Size([16, 10])
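
A common way to train such a classifier (a sketch, not part of the original workflow) is linear probing: freeze the pretrained backbone and optimize only the new linear layer on your labeled data:

# Freeze the pretrained backbone so only the linear head is trained.
for param in model.parameters():
    param.requires_grad = False

# Optimize only the parameters of the linear classification head.
optimizer = torch.optim.Adam(classifier[1].parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# One hypothetical training step with dummy labels.
labels = torch.randint(0, num_classes, (16,))
optimizer.zero_grad()
loss = criterion(classifier(image_batch), labels)
loss.backward()
optimizer.step()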