Train a Self-Supervised Model

Sometimes it can be beneficial to fine-tune a self-supervised learning (SSL) model on your data before embedding the images. This is often the case when the data comes from a specific domain, for example medical images or images from agriculture. The extra training can improve the embedding quality because the model adapts to the specific domain.

For instance, in an agricultural use case, a general-purpose embedding model might only encode that an object is "leaf-like". It would struggle to distinguish between different types of leaves, because that requires storing additional, domain-specific information in the embedding vector. An SSL model trained on leaves, on the other hand, can learn to encode the subtle differences between various types of leaves. This could enable it to differentiate between plant species or to distinguish healthy from unhealthy leaves.

Embeddings From the Target Model vs. an SSL Model

The target model is the final machine learning model used to solve the actual task, and it is trained on the data selected by the Lightly Worker. The target model, for example an object detection model, also produces vectors that can be used as embeddings, e.g. the outputs of its backbone or of its last layer. These vectors aim to capture the characteristics of an object that are important for its classification: they encode the features that help determine which class an object belongs to and leave out characteristics that are irrelevant for classification. One drawback of using an object detector as an embedding model is that the resulting embeddings contain localization information about the object, which does not help with achieving visual diversity.

A car in a target model is a good example of this. The color of a car or the lighting conditions it was captured in help very little in distinguishing it from a truck, so the target model may not encode these properties in its embeddings. SSL models, in contrast, can take all properties into account, including lighting, color, and perspective. This is because SSL models learn to differentiate between individual images, including images of the same class, and not only between classes. For instance, if the same car is captured in multiple images with varying perspectives and lighting conditions, an SSL model produces different embeddings for these images. These embeddings can therefore be used to sample different representations of the same car. This is very useful, for example, if detection fails under certain lighting conditions.

Furthermore, because SSL does not use labels, it is not influenced by the choice of classes or by label errors, which makes it more general-purpose.
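To illustrate how such embeddings can be used to sample visually different images, the sketch below implements greedy farthest-point sampling, a common generic technique for diversity-based selection: it repeatedly picks the embedding that is farthest from everything selected so far. This is only an illustration that assumes plain Euclidean distance on small, in-memory embedding lists. It is not necessarily how the Lightly Worker implements its diversity selection, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def farthest_point_sampling(
    embeddings: Sequence[Sequence[float]], n_samples: int, start: int = 0
) -> List[int]:
    """Greedily select indices of embeddings that are far apart from each other."""
    if not embeddings or n_samples <= 0:
        return []
    n_samples = min(n_samples, len(embeddings))
    selected = [start]
    chosen = {start}
    # Distance from each embedding to its closest already-selected embedding.
    min_dist = [euclidean(e, embeddings[start]) for e in embeddings]
    while len(selected) < n_samples:
        # Pick the unselected embedding that is farthest from the current selection.
        next_idx = max(
            (i for i in range(len(embeddings)) if i not in chosen),
            key=lambda i: min_dist[i],
        )
        selected.append(next_idx)
        chosen.add(next_idx)
        # Update the nearest-selected distances with the newly selected embedding.
        min_dist = [
            min(d, euclidean(e, embeddings[next_idx]))
            for d, e in zip(min_dist, embeddings)
        ]
    return selected


# Two tight clusters plus one outlier: near-duplicates are not picked early.
emb = [[0.0, 0.0], [0.1, 0.0], [10.0, 0.0], [0.0, 10.0], [10.1, 0.0]]
print(farthest_point_sampling(emb, 3))  # [0, 4, 3]
```

Near-duplicate embeddings, such as indices 0 and 1 above, sit close together, so the greedy rule covers distant regions of the embedding space before it picks a second image from the same cluster. This is why embeddings that keep properties like lighting and perspective apart matter for diverse selection.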

Train an SSL Model with the Lightly Worker

The code below schedules a Lightly Worker run that trains a self-supervised model on the input images before embedding them and running the selection algorithm:

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a run with training enabled.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
    lightly_config={
        "loader": {
            "batch_size": 128,
        },
        "trainer": {
            "max_epochs": 100,  # default
        },
    },
)

The model is trained for 100 epochs by default. You can adjust the training settings with the lightly_config argument. The settings you will most commonly want to change are trainer.max_epochs and loader.batch_size, which set the maximum number of training epochs and the batch size, respectively.
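Both settings determine how much training a run performs. As a rough estimate, one epoch over N images takes about N / batch_size optimizer steps, so the total is roughly max_epochs * N / batch_size. The sketch below builds the training part of lightly_config in the nested form used above and computes this estimate. The helper names are hypothetical, and whether the last incomplete batch of each epoch is dropped is an assumption exposed as the drop_last parameter.

```python
import math


def training_lightly_config(max_epochs: int = 100, batch_size: int = 128) -> dict:
    """Build the training-related part of lightly_config (hypothetical helper)."""
    if max_epochs < 1 or batch_size < 1:
        raise ValueError("max_epochs and batch_size must be positive")
    return {
        "loader": {"batch_size": batch_size},
        "trainer": {"max_epochs": max_epochs},
    }


def estimate_training_steps(
    num_images: int, max_epochs: int, batch_size: int, drop_last: bool = True
) -> int:
    """Rough number of optimizer steps for a training run."""
    if drop_last:
        # Assumption: the incomplete last batch of each epoch is skipped.
        steps_per_epoch = num_images // batch_size
    else:
        steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch * max_epochs


print(training_lightly_config(max_epochs=100, batch_size=128))
print(estimate_training_steps(10_000, max_epochs=100, batch_size=128))  # 7800
```

Halving batch_size roughly doubles the number of steps per epoch, while max_epochs scales the total linearly. The returned dictionary can be passed as lightly_config to schedule_compute_worker_run.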

For a complete list of adjustable settings, see Configuration Options.

Train on Object Crops

For datasets containing objects, it is often beneficial to train the model only on the object crops instead of the full images, as this typically results in better embedding and selection quality. See the docs on Crop Selection for how to train a model on object crops and use it for selection.

Checkpoints

Use Checkpoints from Previous Runs

Checkpoints from previous Lightly Worker runs can be reused when scheduling a new run. This allows you to skip the expensive training step.

When you use the datapool feature, that is, when you schedule a run on an existing dataset, the Lightly Worker automatically uses the newest checkpoint from previous Lightly Worker runs on that dataset.

You can also reuse checkpoints from runs on other datasets: Go to the runs overview on the Lightly Platform and click on the run whose checkpoint you want to reuse. The run must have the checkpoint available as an artifact. Copy the run_id of that run and set it in the worker_config as "checkpoint_run_id": "RUN_ID_JUST_COPIED" when scheduling a new run.

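# This workflow requires Lightly Worker v2.9.2+ and Lightly Python Client v1.4.19+
from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a new run, reusing the checkpoint of a previous run.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint_run_id": "PREVIOUS_RUN_ID_WITH_CHECKPOINT_TO_BE_REUSED",
        # Add further worker_config options here.
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}}
        ],
    },
)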
Alternatively, you can fetch a read-URL for the checkpoint of a previous run and pass it as "checkpoint" in the worker_config:

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

### Option 1: Get the run via the dataset ID
runs = client.get_compute_worker_runs(dataset_id="DATASET_ID_WITH_THE_CHECKPOINT_TO_BE_REUSED")
run = runs[-1]  # Get the last run of the dataset.
### Option 2: Get the run via the run ID:
# run = client.get_compute_worker_run(run_id="MY_RUN_ID")

# Get the read-url of the checkpoint.
# CAVEAT: If you are using S3 delegated access, the read-url can expire before
# the Lightly Worker downloads it. If this happens, we recommend generating a
# public read-url for the checkpoint that doesn't expire, or increasing the
# `Maximum session duration`.
checkpoint_url = client.get_compute_worker_run_checkpoint_url(run=run)

# Schedule a new run, reusing the checkpoint from another run.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": checkpoint_url,
        # Add further worker_config options here.
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}}
        ],
    },
)
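Taking the last run of a dataset, as in Option 1, does not guarantee that this run has a checkpoint artifact. The sketch below shows the filtering logic for picking the newest run that does have one. It uses a hypothetical, simplified RunInfo record instead of the run objects returned by the Lightly client; how to read the creation time and artifact types from the real run objects is an assumption to check against the client's API reference.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RunInfo:
    """Hypothetical, simplified stand-in for a Lightly Worker run."""

    run_id: str
    created_at: int  # creation time, e.g. a Unix timestamp in milliseconds
    artifact_types: List[str] = field(default_factory=list)


def newest_run_with_checkpoint(runs: List[RunInfo]) -> Optional[RunInfo]:
    """Return the most recently created run that has a checkpoint artifact."""
    with_checkpoint = [r for r in runs if "CHECKPOINT" in r.artifact_types]
    # default=None covers the case where no run has a checkpoint.
    return max(with_checkpoint, key=lambda r: r.created_at, default=None)


runs = [
    RunInfo("run-a", created_at=1, artifact_types=["CHECKPOINT", "LOG"]),
    RunInfo("run-b", created_at=3, artifact_types=["LOG"]),
    RunInfo("run-c", created_at=2, artifact_types=["CHECKPOINT"]),
]
print(newest_run_with_checkpoint(runs).run_id)  # run-c
```

The newest run overall (run-b) is skipped because it has no checkpoint, which mirrors the requirement that the reused run must have the checkpoint available as an artifact.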

Use Checkpoint from Model Trained Outside the Lightly Worker

If you need a more customized training schedule, you can also train the embedding model outside of the Lightly Worker with your own PyTorch training code. The following code builds a model that is compatible with the Lightly Worker:

from lightly.models import ResNetGenerator, batchnorm
from lightly.models.modules import SimCLRProjectionHead
from torch import nn


class SimCLR(nn.Module):
    def __init__(self, backbone, num_ftrs, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, out_dim)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


def build_model(name="resnet-18", out_dim=128, num_ftrs=32, width=1):
    """Returns a SimCLR model for training."""
    resnet = ResNetGenerator(name=name, width=width)
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(
        backbone=backbone,
        num_ftrs=num_ftrs,
        out_dim=out_dim,
    )
    return model


model = build_model()
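The code above only defines the model; the training loop and the loss are up to you. Because the model is a SimCLR model, a natural objective is the NT-Xent contrastive loss used by SimCLR, which lightly.loss.NTXentLoss provides for PyTorch. To show what this loss computes, the sketch below implements it in plain Python for a tiny batch: each image is augmented twice, and the loss is low when the two projections of the same image are more similar to each other than to the projections of other images. It is for illustration only and not a replacement for a batched PyTorch implementation; the temperature of 0.5 is an assumed value.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (assumes a non-zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent_loss(
    z0: Sequence[Sequence[float]],
    z1: Sequence[Sequence[float]],
    temperature: float = 0.5,
) -> float:
    """NT-Xent loss where z0[i] and z1[i] are projections of two views of image i."""
    z = [_normalize(v) for v in list(z0) + list(z1)]
    n = len(z0)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the other view of the same image
        # Temperature-scaled cosine similarity to every projection in the batch.
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / temperature for k in range(2 * n)]
        # Softmax denominator over all other projections (positive and negatives).
        denom = sum(math.exp(s) for k, s in enumerate(sims) if k != i)
        total += -math.log(math.exp(sims[j]) / denom)
    return total / (2 * n)


views_a = [[1.0, 0.0], [0.0, 1.0]]
print(nt_xent_loss(views_a, [[1.0, 0.0], [0.0, 1.0]]))  # about 0.24, views agree
print(nt_xent_loss(views_a, [[0.0, 1.0], [1.0, 0.0]]))  # about 2.24, views swapped
```

Minimizing this loss pulls the two views of each image together and pushes different images apart, including images of the same class, which is why SSL embeddings keep properties such as lighting and perspective.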

After training, you have to save the model checkpoint to make it loadable by the Lightly Worker:

import torch

torch.save({"state_dict": model.state_dict(prefix="model.")}, "checkpoint.ckpt")

Finally, you have to make the checkpoint available to the Lightly Worker by providing a read-URL for it. You can do this in any way you like, for example by uploading the checkpoint to a cloud bucket and generating a read-URL for it there. Because the Lightly Worker downloads the checkpoint during the embedding step, which can start some time after the run is scheduled, make sure the URL does not expire before then.

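# Schedule a run using a checkpoint provided via read-URL.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "https://my-checkpoint_read-url",
        # Add further worker_config options here.
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}}
        ],
    },
)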
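How long the read-URL has to stay valid depends on how long the run waits for a free Lightly Worker and how long the steps before embedding take. The sketch below turns rough estimates of these durations into an expiration time in seconds with a safety margin. The function name, the safety factor, and the inputs are hypothetical. The 7-day cap is the maximum lifetime of an AWS S3 presigned URL signed with Signature Version 4; other storage providers have different limits.

```python
from datetime import timedelta

# Maximum lifetime of an S3 presigned URL signed with Signature Version 4.
S3_MAX_URL_LIFETIME = timedelta(days=7)


def url_expiration_seconds(
    expected_queue_wait: timedelta,
    expected_time_to_embedding: timedelta,
    safety_factor: float = 2.0,
    max_lifetime: timedelta = S3_MAX_URL_LIFETIME,
) -> int:
    """Return a read-URL expiration in seconds that covers the expected wait plus a margin."""
    needed = (expected_queue_wait + expected_time_to_embedding) * safety_factor
    if needed > max_lifetime:
        raise ValueError(
            f"Required lifetime {needed} exceeds the maximum of {max_lifetime}; "
            "consider another way of sharing the checkpoint."
        )
    return int(needed.total_seconds())


# Run expected to wait about 1 hour in the queue and reach embedding 30 minutes later.
print(url_expiration_seconds(timedelta(hours=1), timedelta(minutes=30)))  # 10800
```

If the URL is signed with temporary credentials, such as those of an assumed IAM role, it also stops working when those credentials expire, regardless of the requested expiration.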


Changing Model Parameters

If you change model parameters of build_model, such as name, out_dim, num_ftrs, or width, you have to specify the same values when scheduling a new run. You do this by setting the model parameters in the Lightly configuration:

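client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "checkpoint.ckpt",
        "enable_training": False,
    },
    lightly_config={
        "model": {
            "name": "resnet-18",
            "out_dim": 128,
            "num_ftrs": 32,
            "width": 1,
        },
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}}
        ],
    },
)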

See Configuration Options for more details.
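One way to keep these values consistent is to define the model parameters once and use the same dictionary both for build_model and for the model section of lightly_config. The sketch below is a hypothetical helper that validates such a dictionary against the four parameters named above and wraps it for lightly_config.

```python
from typing import Any, Dict

# Single source of truth for the model parameters (values match build_model's defaults).
MODEL_PARAMS: Dict[str, Any] = {
    "name": "resnet-18",
    "out_dim": 128,
    "num_ftrs": 32,
    "width": 1,
}

REQUIRED_KEYS = {"name", "out_dim", "num_ftrs", "width"}


def model_lightly_config(params: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Wrap model parameters into the 'model' section of lightly_config."""
    missing = REQUIRED_KEYS - set(params)
    unknown = set(params) - REQUIRED_KEYS
    if missing or unknown:
        raise ValueError(f"missing keys: {sorted(missing)}, unknown keys: {sorted(unknown)}")
    # Copy so later changes to the returned config do not alter MODEL_PARAMS.
    return {"model": dict(params)}


print(model_lightly_config(MODEL_PARAMS))
```

When training outside the Lightly Worker, you would build the model with build_model(**MODEL_PARAMS) and pass lightly_config=model_lightly_config(MODEL_PARAMS) when scheduling the run, so both sides always use the same values.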

Load Model from Checkpoint

You might also want to use the checkpoint independently of the Lightly Worker, for example to train a classifier on your dataset. Starting from a pre-trained checkpoint often results in a model with higher accuracy than training a new model from scratch.

First, get the run whose checkpoint you want and download the checkpoint. The code below shows two options for getting the run:

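from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Option 1: Get the last run on the dataset. Alternatively, filter the runs.
runs = client.get_compute_worker_runs(dataset_id="MY_DATASET_ID")
run = runs[-1]

# Option 2: Get the run from a scheduled run ID:
# run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)

client.download_compute_worker_run_checkpoint(
    run=run, output_path="artifacts/checkpoint.ckpt"
)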

Once you have downloaded the checkpoint to your machine, you can load it as a PyTorch model in your Python application:

from collections import OrderedDict

import torch
from lightly.models import ResNetGenerator, batchnorm


def load_checkpoint(
    checkpoint_path, model_name="resnet-18", model_width=1, map_location="cpu"
):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # Drop the projection head and the 1x1 convolution (backbone.7) that
            # maps the ResNet features to num_ftrs; only the ResNet layers are kept.
            continue
        state_dict[key.replace("model.backbone.", "")] = value

    resnet = ResNetGenerator(name=model_name, width=model_width)
    model = torch.nn.Sequential(
        batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(1),
    )
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        raise RuntimeError(
            f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! "
            f"Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the "
            f"configuration you used to run Lightly. The configuration from a Lightly worker run can be found in output_dir/config/config.yaml"
        ) from err
    return model


# Load the model
model = load_checkpoint(checkpoint_path="artifacts/checkpoint.ckpt")


# Example usage
image_batch = torch.rand(16, 3, 224, 224)
out = model(image_batch)
print(out.shape)  # prints: torch.Size([16, 512])


# Creating a classifier from the pre-trained model
num_classes = 10
classifier = torch.nn.Sequential(
    model, torch.nn.Linear(512, num_classes)  # use 2048 instead of 512 for resnet-50
)

out = classifier(image_batch)
print(out.shape)  # prints: torch.Size([16, 10])
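The input size of the linear layer must match the output dimension of the backbone. Based on the 512 and 2048 values above, and assuming that width scales the channel count of ResNetGenerator linearly (a base of int(64 * width) channels, multiplied by 8 in the last stage and by 4 for bottleneck-based ResNets), the hypothetical helper below computes the expected feature dimension. When in doubt, run a dummy batch through the model and read off out.shape[1] instead, as in the example usage above.

```python
# ResNet variants assumed to use bottleneck blocks (expansion factor 4).
BOTTLENECK_RESNETS = {"resnet-50", "resnet-101", "resnet-152"}
# ResNet variants assumed to use basic blocks (expansion factor 1).
BASIC_RESNETS = {"resnet-9", "resnet-18", "resnet-34"}


def resnet_feature_dim(model_name: str = "resnet-18", width: float = 1) -> int:
    """Expected backbone output dimension for a ResNetGenerator model (assumption)."""
    if model_name in BOTTLENECK_RESNETS:
        expansion = 4
    elif model_name in BASIC_RESNETS:
        expansion = 1
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    base = int(64 * width)  # channels of the first stage
    return base * 8 * expansion  # the last stage has 8x the base channels


print(resnet_feature_dim("resnet-18"))  # 512
print(resnet_feature_dim("resnet-50"))  # 2048
```

With this helper, the classifier head could be written as torch.nn.Linear(resnet_feature_dim(model_name, width), num_classes), so the layer size follows the model_name and model_width passed to load_checkpoint.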