Semantic Segmentation
Note
🔥 New: LightlyTrain now supports training DINOv2 models for semantic segmentation with the train_semantic_segmentation function! The method is based on the state-of-the-art segmentation model EoMT by Kerssies et al. and reaches 58.4% mIoU on the ADE20K dataset.
Training a semantic segmentation model with LightlyTrain is straightforward and only requires a few lines of code. The dataset must follow the ADE20K format with RGB images and integer masks in PNG format. See the Data section for more details.
import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            "train": {
                "images": "my_data_dir/train/images",  # Path to training images
                "masks": "my_data_dir/train/masks",  # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",  # Path to validation images
                "masks": "my_data_dir/val/masks",  # Path to validation masks
            },
            "classes": {  # Classes in the dataset
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0],
        },
    )
After the training completes you can load the model for inference like this:
import lightly_train

model = lightly_train.load_model_from_checkpoint(
    "out/my_experiment/checkpoints/last.ckpt"
)
masks = model.predict("path/to/image.jpg")
And visualize the predicted masks like this:
import torch
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision.utils import draw_segmentation_masks

image = read_image("path/to/image.jpg")
# Build one boolean mask per class ID that occurs in the prediction.
masks = torch.stack([masks == class_id for class_id in masks.unique()])
image_with_masks = draw_segmentation_masks(image, masks, alpha=0.6)
# Reorder (channels, height, width) to (height, width, channels) for matplotlib.
plt.imshow(image_with_masks.permute(1, 2, 0))
plt.show()
The predicted masks have shape (height, width) and each value corresponds to a class ID as defined in the classes dictionary in the dataset.
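To get a quick summary of a prediction, the class IDs in the mask can be counted and mapped back to their names. The following is a minimal sketch and not part of LightlyTrain: class_pixel_counts is a hypothetical helper, and the classes dictionary and toy mask are made up for illustration. It assumes the mask has already been flattened to a list of integers, for example with masks.flatten().tolist() if masks is a torch tensor, as the visualization snippet suggests.

```python
from collections import Counter


def class_pixel_counts(flat_class_ids, classes):
    """Count pixels per class name (hypothetical helper).

    flat_class_ids: iterable of integer class IDs (a flattened mask).
    classes: dict mapping class ID -> class name, as passed to training.
    """
    counts = Counter(flat_class_ids)
    # IDs missing from `classes` are reported under a placeholder name.
    return {
        classes.get(class_id, f"unknown({class_id})"): n
        for class_id, n in sorted(counts.items())
    }


# Toy 2x3 mask, flattened row by row (made-up values).
classes = {0: "background", 1: "car", 2: "bicycle"}
flat_mask = [1, 1, 2, 1, 2, 2]
summary = class_pixel_counts(flat_mask, classes)
print(summary)
```

Dividing each count by the total number of pixels gives the fraction of the image covered by each class.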
Out
The out argument specifies the output directory where all training logs, model exports, and checkpoints are saved. It looks like this after training:
out/my_experiment
├── checkpoints
│   └── last.ckpt                                       # Last checkpoint
├── events.out.tfevents.1721899772.host.1839736.0       # TensorBoard logs
└── train.log                                           # Training logs
The final model checkpoint is saved to out/my_experiment/checkpoints/last.ckpt.
Tip
Create a new output directory for each experiment to keep training logs, model exports, and checkpoints organized.
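One way to follow this tip is to derive the output directory from a timestamp. This is a minimal sketch, not a LightlyTrain feature: new_experiment_dir is a hypothetical helper, and the naming scheme is only one possible convention.

```python
from datetime import datetime
from pathlib import Path


def new_experiment_dir(base="out", name="my_experiment", now=None):
    """Return a path like out/my_experiment_20250102_030405 (not created)."""
    now = now or datetime.now()
    return Path(base) / f"{name}_{now.strftime('%Y%m%d_%H%M%S')}"


out_dir = new_experiment_dir()
print(out_dir)  # e.g. pass str(out_dir) as the `out` argument
```

The helper only builds the path and does not create the directory.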
Data
LightlyTrain supports training semantic segmentation models with images and masks. Every image must have a corresponding mask with the same filename except for the file extension. The masks must be PNG images in grayscale integer format, where each pixel value corresponds to a class ID.
The following image formats are supported:
jpg
jpeg
png
ppm
bmp
pgm
tif
tiff
webp
The following mask formats are supported:
png
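The pairing rule (same filename, different extension) can be checked before training. The sketch below is an assumption-laden illustration, not a LightlyTrain utility: find_unpaired_files is a hypothetical helper, it only looks at files directly inside the two directories, and it compares extensions case-insensitively, which may differ from what LightlyTrain does. The demo builds a throwaway dataset in a temporary directory.

```python
import tempfile
from pathlib import Path

# Image extensions from the list above; masks must be PNG.
IMAGE_EXTENSIONS = {
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp",
}


def find_unpaired_files(images_dir, masks_dir):
    """Return (image stems without a mask, mask stems without an image)."""
    image_stems = {
        p.stem
        for p in Path(images_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    }
    mask_stems = {
        p.stem
        for p in Path(masks_dir).iterdir()
        if p.is_file() and p.suffix.lower() == ".png"
    }
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)


# Demo on a temporary toy dataset: image1.jpg has no mask.
with tempfile.TemporaryDirectory() as tmp:
    images = Path(tmp) / "images"
    masks = Path(tmp) / "masks"
    images.mkdir()
    masks.mkdir()
    for name in ["image0.jpg", "image1.jpg"]:
        (images / name).touch()
    (masks / "image0.png").touch()
    missing_masks, missing_images = find_unpaired_files(images, masks)

print(missing_masks, missing_images)
```

Running the check on both the training and validation directories catches missing or misnamed masks before a training run starts.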
Example of a directory structure with training and validation images and masks:
my_data_dir
├── train
│   ├── images
│   │   ├── image0.jpg
│   │   └── image1.jpg
│   └── masks
│       ├── image0.png
│       └── image1.png
└── val
    ├── images
    │   ├── image2.jpg
    │   └── image3.jpg
    └── masks
        ├── image2.png
        └── image3.png
To train with this folder structure, set the data argument like this:
import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            "train": {
                "images": "my_data_dir/train/images",  # Path to training images
                "masks": "my_data_dir/train/masks",  # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",  # Path to validation images
                "masks": "my_data_dir/val/masks",  # Path to validation masks
            },
            "classes": {  # Classes in the dataset
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0],
        },
    )
The classes in the dataset must be specified in the classes dictionary. The keys are the class IDs and the values are the class names. The class IDs must be identical to the values in the mask images. Every class ID that occurs in the masks must be listed; LightlyTrain raises an error when it encounters an unknown class ID. To ignore some classes during training, list their class IDs in the ignore_classes argument. The trained model will then not predict these classes.
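Because unknown class IDs cause an error, it can help to scan the masks for undeclared IDs up front. This is a minimal sketch, not a LightlyTrain function: find_unknown_class_ids and mask_values_from_png are hypothetical helpers, and the PNG reader assumes Pillow and NumPy are installed and that the mask is a single-channel image. The demo runs on made-up pixel values and has no third-party dependencies.

```python
def find_unknown_class_ids(mask_values, classes):
    """Return the sorted class IDs that occur in a mask but not in `classes`."""
    return sorted(set(mask_values) - set(classes))


def mask_values_from_png(path):
    """Return the distinct pixel values of a grayscale PNG mask.

    Assumes Pillow and NumPy are available; imported lazily so the
    check above can be used without them.
    """
    import numpy as np
    from PIL import Image

    with Image.open(path) as mask:
        return np.unique(np.asarray(mask)).tolist()


classes = {0: "background", 1: "car", 2: "bicycle"}
toy_mask_values = [0, 0, 1, 2, 7, 1]  # 7 is not declared in `classes`
unknown = find_unknown_class_ids(toy_mask_values, classes)
print(unknown)
```

For a real dataset, the values would come from mask_values_from_png applied to each file in the masks directories.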
Model
The model argument defines the model used for semantic segmentation training. The following models are available:
dinov2/vits14-eomt
dinov2/vitb14-eomt
dinov2/vitl14-eomt
dinov2/vitg14-eomt
All DINOv2 models are pretrained by Meta.
Logging
Logging is configured with the logger_args argument. The following loggers are supported:
mlflow: Logs training metrics to MLflow (disabled by default, requires MLflow to be installed)
tensorboard: Logs training metrics to TensorBoard (enabled by default, requires TensorBoard to be installed)
MLflow
Important
MLflow must be installed with pip install "lightly-train[mlflow]".
The mlflow logger can be configured with the following arguments:
import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            # ...
        },
        logger_args={
            "mlflow": {
                "experiment_name": "my_experiment",
                "run_name": "my_run",
                "tracking_uri": "tracking_uri",
            },
        },
    )
TensorBoard
TensorBoard logs are automatically saved to the output directory. Run TensorBoard in a new terminal to visualize the training progress:
tensorboard --logdir out/my_experiment
Disable the TensorBoard logger with:
logger_args={"tensorboard": None}
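The two logger settings can presumably be combined in a single logger_args dictionary. The sketch below only assembles such a dictionary from the keys shown in this section; build_logger_args is a hypothetical helper, and that LightlyTrain accepts both entries together is an assumption rather than something stated here.

```python
def build_logger_args(use_tensorboard=True, mlflow_config=None):
    """Assemble a `logger_args` dictionary from the options in this section."""
    logger_args = {}
    if not use_tensorboard:
        # Setting a logger to None disables it.
        logger_args["tensorboard"] = None
    if mlflow_config is not None:
        # Copy so later changes to the caller's dict do not leak in.
        logger_args["mlflow"] = dict(mlflow_config)
    return logger_args


logger_args = build_logger_args(
    use_tensorboard=False,
    mlflow_config={
        "experiment_name": "my_experiment",
        "run_name": "my_run",
        "tracking_uri": "tracking_uri",  # placeholder, as in the example above
    },
)
print(logger_args)
```

The resulting dictionary would be passed as logger_args to train_semantic_segmentation.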
Pretrain and Fine-tune a Semantic Segmentation Model
To further improve the performance of your semantic segmentation model, you can first pretrain a DINOv2 model on unlabeled data using self-supervised learning and then fine-tune it on your segmentation dataset. This is especially useful if your dataset is only partially labeled or if you have access to a large amount of unlabeled data.
The following example shows how to pretrain and fine-tune the model. Check out the page on DINOv2 to learn more about pretraining DINOv2 models on unlabeled data.
import lightly_train

if __name__ == "__main__":
    # Pretrain a DINOv2 model.
    lightly_train.train(
        out="out/my_pretrain_experiment",
        data="my_pretrain_data_dir",
        model="dinov2/vitl14",
        method="dinov2",
        epochs=100,  # We recommend epochs = 125000 * batch size // dataset size
    )

    # Fine-tune the DINOv2 model for semantic segmentation.
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        model_args={
            # Path to your pretrained DINOv2 model.
            "backbone_weights": "out/my_pretrain_experiment/exported_models/exported_last.pt",
        },
        data={
            "train": {
                "images": "my_data_dir/train/images",  # Path to training images
                "masks": "my_data_dir/train/masks",  # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",  # Path to validation images
                "masks": "my_data_dir/val/masks",  # Path to validation masks
            },
            "classes": {  # Classes in the dataset
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0],
        },
    )
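The comment on the epochs argument gives a rule of thumb: epochs = 125000 * batch size // dataset size. Since one epoch takes about dataset size / batch size steps, this amounts to roughly 125,000 training steps in total. The sketch below simply evaluates that formula; recommended_pretrain_epochs is a hypothetical helper, and the batch size and dataset size are made-up numbers.

```python
def recommended_pretrain_epochs(batch_size, dataset_size):
    """Evaluate epochs = 125000 * batch size // dataset size."""
    return 125000 * batch_size // dataset_size


# Made-up example: 100,000 unlabeled images and a batch size of 128.
epochs = recommended_pretrain_epochs(batch_size=128, dataset_size=100_000)
print(epochs)  # 160
```

The result would then be passed as the epochs argument of the pretraining call.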