Train a Self-Supervised Model

It can be beneficial to fine-tune a self-supervised model on your data before embedding the images, particularly when the data comes from a specific domain such as medical imaging. The extra training lets the model adapt to that domain, which typically improves the quality of the embeddings.

The code below trains a self-supervised model on the input images before embedding them and running the selection algorithm:

from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a run with training enabled.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
)

The model is trained for 100 epochs by default. You can adjust the training settings with the lightly_config argument. The settings most commonly changed are max_epochs, the maximum number of training epochs, and num_workers, the number of parallel data loading processes:

client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True,
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
    lightly_config={
        "loader": {
            "num_workers": -1,
        },
        "trainer": {
            "max_epochs": 100,
        },
    },
)

A complete list of the settings that can be adjusted is available under Configuration Options.
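Because the same selection and training dictionaries recur in every call on this page, it can help to build them with small helper functions. The sketch below is illustrative only: diversity_selection_config and training_lightly_config are hypothetical helper names, not part of the Lightly client, and they only reproduce the keys shown in the examples above.

```python
def diversity_selection_config(n_samples):
    """Build a selection config that selects n_samples diverse images by embedding."""
    return {
        "n_samples": n_samples,
        "strategies": [
            {
                "input": {"type": "EMBEDDINGS"},
                "strategy": {"type": "DIVERSITY"},
            }
        ],
    }


def training_lightly_config(max_epochs=100, num_workers=-1):
    """Build a lightly_config that overrides only the two settings shown above."""
    return {
        "loader": {"num_workers": num_workers},
        "trainer": {"max_epochs": max_epochs},
    }


selection_config = diversity_selection_config(n_samples=50)
lightly_config = training_lightly_config(max_epochs=50)
print(selection_config["n_samples"], lightly_config["trainer"]["max_epochs"])
```

The resulting dictionaries can then be passed as the selection_config and lightly_config arguments of client.schedule_compute_worker_run.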

Checkpoints

Checkpoints from your training process will be stored in the Lightly Platform as artifacts. The last checkpoint from a run can be downloaded using the Lightly Python client:

runs = client.get_compute_worker_runs()
run = runs[-1]  # Get last run on the dataset.

# Alternatively you can get the run from a scheduled run id:
# run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)

client.download_compute_worker_run_checkpoint(run=run, output_path="artifacts/checkpoint.ckpt")
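The example above writes to artifacts/checkpoint.ckpt. This page does not say whether the client creates missing parent directories, so creating the directory beforehand is a harmless precaution. The sketch below uses only the Python standard library; prepare_checkpoint_path is a hypothetical helper name.

```python
from pathlib import Path


def prepare_checkpoint_path(output_path):
    """Create the parent directory of output_path if needed and return the path as a string."""
    path = Path(output_path)
    # exist_ok=True makes the call safe to repeat across runs.
    path.parent.mkdir(parents=True, exist_ok=True)
    return str(path)


checkpoint_path = prepare_checkpoint_path("artifacts/checkpoint.ckpt")
print(checkpoint_path)
```

The returned string can be passed as the output_path argument of client.download_compute_worker_run_checkpoint.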

Use Checkpoints from Previous Runs

Checkpoints from previous runs can be reused when scheduling a new run, which lets you skip the expensive training step. To make the checkpoint available to the Lightly Worker, download it into the shared directory ({SHARED_DIR}) on the machine running the Lightly Worker Docker container:

client.download_compute_worker_run_checkpoint(run=run, output_path=f"{SHARED_DIR}/checkpoint.ckpt")

It is called the shared directory because its contents are shared with the Lightly Worker Docker container. The directory has to be mounted into the container when starting the Lightly Worker:

docker run --shm-size="1024m" --gpus all --rm -it \
    -v ${SHARED_DIR}:/shared_dir \
    -e LIGHTLY_TOKEN={MY_LIGHTLY_TOKEN} \
    lightly/worker:latest \
    worker.worker_id={MY_WORKER_ID}

The checkpoint in the shared directory is now available when scheduling a new run. As in the example below, it is referenced by its file name within the shared directory:

client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "checkpoint.ckpt",
        "enable_training": False,               # Set to True to continue training
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    }
)

Load Model from Checkpoint

You might also want to use the checkpoint independently of the Lightly Worker, for example, to train a classifier on your dataset. Starting from a pretrained checkpoint often yields higher accuracy than training a new model from scratch. The code below demonstrates how the checkpoint can be loaded as a PyTorch model in your Python application:

from collections import OrderedDict
import torch
import lightly


def load_checkpoint(
    checkpoint_path, model_name="resnet-18", model_width=1, map_location="cpu"
):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # drop layers used for projection head
            continue
        state_dict[key.replace("model.backbone.", "")] = value

    resnet = lightly.models.ResNetGenerator(name=model_name, width=model_width)
    model = torch.nn.Sequential(
        lightly.models.batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(1),
    )
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        # Chain the original error so the mismatching keys stay visible in the traceback.
        raise RuntimeError(
            f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! "
            "Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the "
            "configuration you used to run Lightly. The configuration from a Lightly Worker run can be found in output_dir/config/config.yaml"
        ) from err
    return model


# Load the model
model = load_checkpoint(checkpoint_path="artifacts/checkpoint.ckpt")


# Example usage
image_batch = torch.rand(16, 3, 224, 224)
out = model(image_batch)
print(out.shape)  # prints: torch.Size([16, 512])


# Creating a classifier from the pre-trained model
num_classes = 10
classifier = torch.nn.Sequential(
    model, torch.nn.Linear(512, num_classes)  # use 2048 instead of 512 for resnet-50
)

out = classifier(image_batch)
print(out.shape)  # prints: torch.Size([16, 10])
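The key handling inside load_checkpoint can be easier to follow in isolation. The sketch below applies the same filtering and renaming rules to a plain dictionary, so it runs without PyTorch. The key names and string values in the example are made up for illustration and are not taken from a real checkpoint, and filter_backbone_keys is a hypothetical helper name.

```python
from collections import OrderedDict


def filter_backbone_keys(checkpoint_state_dict):
    """Drop projection-head entries and strip the 'model.backbone.' prefix."""
    filtered = OrderedDict()
    for key, value in checkpoint_state_dict.items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # Same rule as in load_checkpoint: these layers are not loaded.
            continue
        filtered[key.replace("model.backbone.", "")] = value
    return filtered


# Illustrative keys only; the values stand in for weight tensors.
example = {
    "model.backbone.0.weight": "w0",
    "model.backbone.1.conv1.weight": "w1",
    "model.backbone.7.weight": "dropped",
    "model.projection_head.layers.0.weight": "dropped",
}
filtered = filter_backbone_keys(example)
print(list(filtered))  # prints: ['0.weight', '1.conv1.weight']
```

After renaming, the remaining keys begin with the layer index, which is how torch.nn.Sequential names its submodules. This is why the filtered state dict can be loaded into the Sequential model built in load_checkpoint.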