Train a Self-Supervised Model
Sometimes it may be beneficial to finetune a self-supervised model on your data before embedding the images. This may be the case when the data is from a specific domain, for example, medical images. The extra training increases the embedding quality as the model can adapt to the specific domain.
The code below schedules a run that trains a self-supervised model on the input images before embedding them and running the selection algorithm:
from lightly.api import ApiWorkflowClient

# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Schedule a run with training enabled.
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
)
The model is trained for 100 epochs by default. You can adjust the training settings using the lightly_config argument. The most common settings you might want to change are max_epochs and num_workers, which determine the maximum number of training epochs and the number of parallel data loading processes, respectively:
client.schedule_compute_worker_run(
    worker_config={
        "enable_training": True,
    },
    selection_config={
        "n_samples": 50,
        "strategies": [
            {
                "input": {
                    "type": "EMBEDDINGS"
                },
                "strategy": {
                    "type": "DIVERSITY"
                }
            }
        ]
    },
    lightly_config={
        "loader": {
            "num_workers": -1,
        },
        "trainer": {
            "max_epochs": 100,
        },
    },
)
You can find a complete list of settings that can be adjusted here: Configuration Options.
Checkpoints
Checkpoints from your training process will be stored in the Lightly Platform as artifacts. The last checkpoint from a run can be downloaded using the Lightly Python client:
runs = client.get_compute_worker_runs()
run = runs[-1] # Get last run on the dataset.
# Alternatively you can get the run from a scheduled run id:
# run = client.get_compute_worker_run_from_scheduled_run(scheduled_run_id=scheduled_run_id)
client.download_compute_worker_run_checkpoint(
    run=run, output_path="artifacts/checkpoint.ckpt"
)
Use Checkpoints from Previous Runs
Checkpoints from previous runs can be reused when scheduling a new run. This allows you to skip the expensive training step. You can also reuse checkpoints from runs on other datasets:
# Create the Lightly client to connect to the API.
client = ApiWorkflowClient(token="MY_LIGHTLY_TOKEN", dataset_id="MY_DATASET_ID")

# Get the read-URL of a checkpoint from a previous run on another dataset.
## Option 1: Get the run via the dataset ID.
runs = client.get_compute_worker_runs(dataset_id="DATASET_ID_WITH_THE_CHECKPOINT_TO_BE_REUSED")
run = runs[-1]  # Get the last run of the dataset.
## Option 2: Get the run via the run ID.
# run = client.get_compute_worker_run(run_id="MY_RUN_ID")

checkpoint_url = client.get_compute_worker_run_checkpoint_url(run=run)

# Schedule a new run, reusing the checkpoint from another run.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": checkpoint_url,
        ...
    },
    selection_config={...},
)
Use Checkpoint from Model Trained Outside the Lightly Worker
If you require a more custom training schedule, you can also train the embedding model outside of the Lightly Worker using your own PyTorch training code. The following code generates a model that is compatible with the Lightly Worker:
from lightly.models import ResNetGenerator, batchnorm
from lightly.models.modules import SimCLRProjectionHead
from torch import nn


class SimCLR(nn.Module):
    def __init__(self, backbone, num_ftrs, out_dim):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, out_dim)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z


def build_model(name="resnet-18", out_dim=128, num_ftrs=32, width=1):
    """Returns a SimCLR model for training."""
    resnet = ResNetGenerator(name=name, width=width)
    last_conv_channels = list(resnet.children())[-1].in_features
    backbone = nn.Sequential(
        batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        nn.Conv2d(last_conv_channels, num_ftrs, 1),
        nn.AdaptiveAvgPool2d(1),
    )
    model = SimCLR(
        backbone=backbone,
        num_ftrs=num_ftrs,
        out_dim=out_dim,
    )
    return model


model = build_model()
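For reference, a minimal SimCLR training loop for this model could look like the sketch below. It assumes a recent version of lightly that provides SimCLRTransform, that your images live in a local folder ("path/to/images" is a placeholder), and it uses arbitrary example values for the batch size, learning rate, and number of epochs:

import torch

from lightly.data import LightlyDataset
from lightly.loss import NTXentLoss
from lightly.transforms import SimCLRTransform

# Create two augmented views per image, as required by SimCLR.
transform = SimCLRTransform(input_size=224)
dataset = LightlyDataset("path/to/images", transform=transform)  # placeholder path
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=128, shuffle=True, drop_last=True
)

criterion = NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06, momentum=0.9, weight_decay=5e-4)

for epoch in range(100):  # example epoch count
    for views, _, _ in dataloader:
        x0, x1 = views  # the two augmented views of the batch
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()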
After training, you have to save the model checkpoint to make it loadable by the Lightly Worker:
import torch
torch.save({"state_dict": model.state_dict(prefix="model.")}, "checkpoint.ckpt")
Finally, you have to make the checkpoint available to the Lightly Worker by providing a read-URL for it. How you do this is up to you; for example, you can upload the checkpoint to a cloud bucket and generate a read-URL for it there. Since the Lightly Worker downloads the checkpoint as part of the embedding step, make sure the URL does not expire before the run reaches that step.
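As an illustration, the sketch below uploads the checkpoint to an S3 bucket with boto3 and generates a presigned read-URL. The bucket name, object key, and expiration are placeholder values, and any storage provider that can issue a time-limited read-URL works just as well:

import boto3

s3 = boto3.client("s3")

# Placeholder bucket and key; replace with your own.
bucket = "my-checkpoint-bucket"
key = "checkpoints/checkpoint.ckpt"

# Upload the checkpoint saved in the previous step.
s3.upload_file("checkpoint.ckpt", bucket, key)

# Generate a presigned read-URL; make sure it lives long enough for the run to start.
checkpoint_read_url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": bucket, "Key": key},
    ExpiresIn=7 * 24 * 60 * 60,  # 7 days, in seconds
)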
# Schedule a run using a checkpoint provided via read-URL.
scheduled_run_id = client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "https://my-checkpoint_read-url",
        ...
    },
    selection_config={...},
)
Changing Model Parameters
If you change model parameters, such as name, out_dim, num_ftrs, or width, you have to specify them when scheduling a new run. This is done by setting the model parameters in the Lightly configuration:

client.schedule_compute_worker_run(
    worker_config={
        "checkpoint": "checkpoint.ckpt",
        "enable_training": False,
    },
    lightly_config={
        "model": {
            "name": "resnet-18",
            "out_dim": 128,
            "num_ftrs": 32,
            "width": 1,
        },
    },
    ...
)
See Configuration Options for more details.
Load Model from Checkpoint
You might also want to use the checkpoint independently of the Lightly Worker, for example, to train a classifier on your dataset. Using a pretrained checkpoint often results in models with higher accuracy than training a new model from scratch. The code below demonstrates how the checkpoint can be loaded as a PyTorch model in your Python application:
from collections import OrderedDict

import torch

import lightly


def load_checkpoint(
    checkpoint_path, model_name="resnet-18", model_width=1, map_location="cpu"
):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    state_dict = OrderedDict()
    for key, value in checkpoint["state_dict"].items():
        if ("projection_head" in key) or ("backbone.7" in key):
            # Drop layers used for the projection head.
            continue
        state_dict[key.replace("model.backbone.", "")] = value

    resnet = lightly.models.ResNetGenerator(name=model_name, width=model_width)
    model = torch.nn.Sequential(
        lightly.models.batchnorm.get_norm_layer(3, 0),
        *list(resnet.children())[:-1],
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(1),
    )
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        raise RuntimeError(
            f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! "
            f"Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the "
            f"configuration you used to run Lightly. The configuration from a Lightly Worker run can be found in output_dir/config/config.yaml"
        )
    return model


# Load the model.
model = load_checkpoint(checkpoint_path="artifacts/checkpoint.ckpt")

# Example usage.
image_batch = torch.rand(16, 3, 224, 224)
out = model(image_batch)
print(out.shape)  # prints: torch.Size([16, 512])

# Creating a classifier from the pre-trained model.
num_classes = 10
classifier = torch.nn.Sequential(
    model,
    torch.nn.Linear(512, num_classes),  # use 2048 instead of 512 for resnet-50
)
out = classifier(image_batch)
print(out.shape)  # prints: torch.Size([16, 10])
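To actually fine-tune this classifier, a standard supervised training loop is all that is needed. The sketch below uses random placeholder tensors in place of your labeled data and arbitrary example hyperparameters:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder labeled data; replace with your own (image, label) dataset.
images = torch.rand(64, 3, 224, 224)
labels = torch.randint(0, num_classes, (64,))
labeled_dataloader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)  # example learning rate

classifier.train()
for epoch in range(10):  # example epoch count
    for batch_images, batch_labels in labeled_dataloader:
        logits = classifier(batch_images)
        loss = criterion(logits, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()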