Train a Self-Supervised Model

It can be beneficial to fine-tune a self-supervised model on your data before embedding the images, for example when the data comes from a specific domain such as medical imaging. The extra training can improve the embedding quality because the model adapts to that domain.

The code below trains a self-supervised model on the input images before embedding them and running the selection algorithm:

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a run with training enabled.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
)

The model is trained for 100 epochs by default. You can adjust the training settings using the lightly_config argument. The most common settings you might want to change in the lightly_config are the max_epochs and num_workers, which determine the maximum number of training epochs and number of parallel data loading processes, respectively:

client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True,
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
    lightly_config={
        "loader": {
            "num_workers": -1,
        },
        "trainer": {
            "max_epochs": 100,
        },
    },
)

You can find a complete list of settings that can be adjusted here: Configuration Options.
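
As a rough sketch, the two settings discussed above can be assembled by a small helper so that obviously invalid values are caught before a run is scheduled. The function name make_lightly_config and its validation rules are assumptions made for illustration and are not part of the Lightly API; only the loader.num_workers and trainer.max_epochs keys come from the example above.

```python
def make_lightly_config(max_epochs=100, num_workers=-1):
    """Build a lightly_config dict (hypothetical helper, not part of the Lightly API)."""
    # Assumed validation rules: epochs cannot be negative, and -1 is the only
    # negative value accepted for num_workers (it is the value used above).
    if max_epochs < 0:
        raise ValueError("max_epochs must not be negative")
    if num_workers < -1:
        raise ValueError("num_workers must be -1 or a non-negative integer")
    return {
        "loader": {"num_workers": num_workers},
        "trainer": {"max_epochs": max_epochs},
    }


# Example: a shorter training run with four data loading processes.
lightly_config = make_lightly_config(max_epochs=20, num_workers=4)
print(lightly_config)
```

The resulting dictionary can be passed as the lightly_config argument of client.schedule_compute_worker_run.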

Checkpoints

Checkpoints from your training process will be stored in the Lightly Platform as artifacts. The last checkpoint from a run can be downloaded using the Lightly Python client:

runs = client.get_compute_worker_runs()
run = runs[-1]  # Get last run on the dataset.

# Alternatively you can get the run from a scheduled run id:
# run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)

client.download_compute_worker_run_checkpoint(
    run=run, output_path="artifacts/checkpoint.ckpt"
)
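
In case the target directory does not exist yet, it can be created before the download. The sketch below uses only the Python standard library; the helper name prepare_output_path is hypothetical, and whether the Lightly client creates missing directories itself is not stated here, so treat this as a defensive assumption.

```python
import tempfile
from pathlib import Path


def prepare_output_path(output_path):
    """Create the parent directory of output_path if needed and return the path as a string."""
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    return str(path)


# Demonstration inside a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    output_path = prepare_output_path(Path(tmp_dir) / "artifacts" / "checkpoint.ckpt")
    print(Path(output_path).parent.is_dir())  # True
```

The returned string can then be used as the output_path argument of client.download_compute_worker_run_checkpoint.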

Use Checkpoints from Previous Runs

Checkpoints from previous runs can be reused when scheduling a new run. This allows you to skip the expensive training step. You can also reuse checkpoints from runs on other datasets:

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Get the read-URL of a previous run on another dataset.
## Option 1: Get the run via the dataset ID
runs = client.get_compute_worker_runs(dataset_id="DATASET_ID_WITH_THE_CHECKPOINT_TO_BE_REUSED")
run = runs[-1]  # Get the last run of the dataset.
## Option 2: Get the run via the run ID:
# run = client.get_compute_worker_run(run_id="MY_RUN_ID")

checkpoint_url = client.get_compute_worker_run_checkpoint_url(run=run)

# Schedule a new run, reusing the checkpoint from another run.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": checkpoint_url,
        ...
    },
    selection_config={...},
)

Use Checkpoint from Model Trained Outside the Lightly Worker

If you require a more custom training schedule, you can also train the embedding model outside of the Lightly Worker using your own PyTorch training code. The following code generates a model that is compatible with the Lightly Worker:

from lightly.models import ResNetGenerator, batchnorm
from lightly.models.modules import SimCLRProjectionHead
from torch import nn


class SimCLR(nn.Module):
    def __init__(self, backbone, num_ftrs, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, out_dim)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


def build_model(name="resnet-18", out_dim=128, num_ftrs=32, width=1):
    """Returns a SimCLR model for training."""
    resnet = ResNetGenerator(name=name, width=width)
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(
        backbone=backbone,
        num_ftrs=num_ftrs,
        out_dim=out_dim,
    )
    return model


model = build_model()

After training, save the model checkpoint in a format that the Lightly Worker can load:

import torch

torch.save({"state_dict": model.state_dict(prefix="model.")}, "checkpoint.ckpt")
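
Before uploading, it can help to sanity-check the checkpoint layout. The sketch below assumes, mirroring the torch.save call above, that the checkpoint needs a top-level state_dict entry whose keys start with model. The helper name check_checkpoint_layout and the example key names are hypothetical; the function works on any plain mapping, so it can be applied to the result of torch.load.

```python
def check_checkpoint_layout(checkpoint, prefix="model."):
    """Return a list of problems found in a checkpoint dict (hypothetical helper)."""
    state_dict = checkpoint.get("state_dict")
    if state_dict is None:
        return ["missing top-level 'state_dict' entry"]
    problems = []
    if not state_dict:
        problems.append("'state_dict' is empty")
    for key in state_dict:
        if not key.startswith(prefix):
            problems.append(f"key {key!r} does not start with {prefix!r}")
    return problems


# Example with placeholder values instead of real tensors (key names are made up).
good = {"state_dict": {"model.backbone.1.weight": 0, "model.projection_head.layers.0.weight": 0}}
bad = {"state_dict": {"backbone.1.weight": 0}}
print(check_checkpoint_layout(good))  # []
print(check_checkpoint_layout(bad))
```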

Finally, make the checkpoint available to the Lightly Worker by providing a read-URL to it. How you do this is up to you. For example, you can upload the checkpoint to a cloud bucket and generate a read-URL for it there. Because the Lightly Worker downloads the checkpoint as part of the embedding step, set the expiration of the URL so that it is still valid when that step runs.

# Schedule a run using a checkpoint provided via read-URL
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "https://my-checkpoint_read-url",
        ...
    },
    selection_config={...},
)
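
As one possible way to create such a read-URL, the sketch below assumes the checkpoint was uploaded to an Amazon S3 bucket and uses the generate_presigned_url method of a boto3 S3 client. The bucket name, object key, and expiration time are placeholders, and the helper make_checkpoint_read_url is hypothetical; other cloud providers offer comparable signed-URL mechanisms.

```python
def make_checkpoint_read_url(s3_client, bucket, key, expires_in_seconds=24 * 3600):
    """Create a time-limited read-URL for a checkpoint stored in S3.

    s3_client is expected to be a boto3 S3 client, e.g. boto3.client("s3").
    The default expiration of one day is an arbitrary assumption.
    """
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expires_in_seconds,
    )


# Usage (requires boto3 and AWS credentials; the names below are placeholders):
# import boto3
# checkpoint_url = make_checkpoint_read_url(
#     boto3.client("s3"), bucket="my-bucket", key="checkpoints/checkpoint.ckpt"
# )
```

Choose an expiration that covers the time until the run is picked up and reaches the embedding step.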

Changing Model Parameters

If you change model parameters such as name, out_dim, num_ftrs, or width, you have to specify them when scheduling a new run. Do this by setting the model parameters in the Lightly configuration:

client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "checkpoint.ckpt",
        "enable_training": False,
    },
    lightly_config={
        "model": {
            "name": "resnet-18",
            "out_dim": 128,
            "num_ftrs": 32,
            "width": 1,
        },
    },
    ...
)

See Configuration Options for more details.
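
One way to keep the training code and the run configuration in sync is to define the model parameters once and reuse them in both places. This is only a sketch: MODEL_PARAMS and model_section are hypothetical names, and the values mirror the defaults of build_model shown earlier.

```python
# Single source of truth for the model parameters (values are the defaults
# of build_model from the training code above).
MODEL_PARAMS = {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}


def model_section(params=MODEL_PARAMS):
    """Return the lightly_config fragment describing the model (hypothetical helper)."""
    required = {"name", "out_dim", "num_ftrs", "width"}
    missing = required - params.keys()
    if missing:
        raise KeyError(f"missing model parameters: {sorted(missing)}")
    # Copy the dict so later changes to the config do not alter MODEL_PARAMS.
    return {"model": dict(params)}


lightly_config = model_section()
print(lightly_config)

# The same dict can be passed to the training code, e.g. build_model(**MODEL_PARAMS).
```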

Load Model from Checkpoint

You might also want to use the checkpoint independently of the Lightly Worker, for example, to train a classifier on your dataset. Using a pretrained checkpoint often results in models with higher accuracy than training a new model from scratch. The code below demonstrates how the checkpoint can be loaded as a PyTorch model in your Python application:

from collections import OrderedDict
import torch
import lightly


def load_checkpoint(
    checkpoint_path, model_name="resnet-18", model_width=1, map_location="cpu"
):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # Drop the projection head and the 1x1 convolution at backbone.7
            # that reduces the backbone features to num_ftrs.
            continue
        state_dict[key.replace("model.backbone.", "")] = value

    resnet = lightly.models.ResNetGenerator(name=model_name, width=model_width)
    model = torch.nn.Sequential(
        lightly.models.batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(1),
    )
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        # Chain the original error so the mismatching keys stay visible.
        raise RuntimeError(
            f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! "
            "Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the "
            "configuration you used to run Lightly. The configuration from a Lightly Worker run can be found in output_dir/config/config.yaml"
        ) from err
    return model


# Load the model
model = load_checkpoint(checkpoint_path="artifacts/checkpoint.ckpt")


# Example usage
image_batch = torch.rand(16, 3, 224, 224)
out = model(image_batch)
print(out.shape)  # prints: torch.Size([16, 512])


# Creating a classifier from the pre-trained model
num_classes = 10
classifier = torch.nn.Sequential(
    model, torch.nn.Linear(512, num_classes)  # use 2048 instead of 512 for resnet-50
)

out = classifier(image_batch)
print(out.shape)  # prints: torch.Size([16, 10])