Semantic Segmentation

Note

🔥 New: LightlyTrain now supports training DINOv2 models for semantic segmentation with the train_semantic_segmentation function! The method is based on the state-of-the-art segmentation model EoMT by Kerssies et al. and reaches 58.4% mIoU on the ADE20K dataset.

Training a semantic segmentation model with LightlyTrain is straightforward and requires only a few lines of code. The dataset must follow the ADE20K format with RGB images and integer masks in PNG format. See Data for more details.

import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            "train": {
                "images": "my_data_dir/train/images",   # Path to training images
                "masks": "my_data_dir/train/masks",     # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",     # Path to validation images
                "masks": "my_data_dir/val/masks",       # Path to validation masks
            },
            "classes": {                                # Classes in the dataset                    
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0], 
        },
    )

After training completes, you can load the model for inference like this:

import lightly_train

model = lightly_train.load_model_from_checkpoint(
    "out/my_experiment/checkpoints/last.ckpt"
)
masks = model.predict("path/to/image.jpg")

And visualize the predicted masks like this:

import matplotlib.pyplot as plt
import torch
from torchvision.io import read_image
from torchvision.utils import draw_segmentation_masks

image = read_image("path/to/image.jpg")
# Convert the (height, width) class-ID mask into one boolean mask per class.
binary_masks = torch.stack([masks == class_id for class_id in masks.unique()])
image_with_masks = draw_segmentation_masks(image, binary_masks, alpha=0.6)
plt.imshow(image_with_masks.permute(1, 2, 0))
plt.axis("off")
plt.show()

The predicted masks have shape (height, width) and each value corresponds to a class ID as defined in the classes dictionary in the dataset.
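
For example, you can map the predicted class IDs back to their names and count the pixels assigned to each class. This is a small sketch that reuses the classes dictionary from the training example above:

classes = {0: "background", 1: "car", 2: "bicycle"}

# Count how many pixels were predicted for each class in the mask.
for class_id in masks.unique().tolist():
    num_pixels = (masks == class_id).sum().item()
    print(f"{classes.get(class_id, 'unknown')}: {num_pixels} pixels")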

Out

The out argument specifies the output directory where all training logs, model exports, and checkpoints are saved. It looks like this after training:

out/my_experiment
├── checkpoints
│   └── last.ckpt                                       # Last checkpoint
β”œβ”€β”€ events.out.tfevents.1721899772.host.1839736.0       # TensorBoard logs
└── train.log                                           # Training logs

The final model checkpoint is saved to out/my_experiment/checkpoints/last.ckpt.

Tip

Create a new output directory for each experiment to keep training logs, model exports, and checkpoints organized.

Data

LightlyTrain supports training semantic segmentation models with images and masks. Every image must have a corresponding mask with the same filename except for the file extension. The masks must be PNG images in grayscale integer format, where each pixel value corresponds to a class ID.

The following image formats are supported:

  • jpg

  • jpeg

  • png

  • ppm

  • bmp

  • pgm

  • tif

  • tiff

  • webp

The following mask formats are supported:

  • png
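
For example, a compliant mask can be written with Pillow by saving a two-dimensional uint8 array of class IDs as a grayscale PNG. This is a minimal sketch, assuming Pillow and NumPy are installed:

import numpy as np
from PIL import Image

# A 256x256 mask where every pixel is class 0 (background) except a
# square of class 1 (car).
mask = np.zeros((256, 256), dtype=np.uint8)
mask[50:150, 50:150] = 1
Image.fromarray(mask).save("my_data_dir/train/masks/image0.png")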

Example of a directory structure with training and validation images and masks:

my_data_dir
├── train
│   ├── images
│   │   ├── image0.jpg
│   │   └── image1.jpg
│   └── masks
│       ├── image0.png
│       └── image1.png
└── val
    ├── images
    │   ├── image2.jpg
    │   └── image3.jpg
    └── masks
        ├── image2.png
        └── image3.png

To train with this folder structure, set the data argument like this:

import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            "train": {
                "images": "my_data_dir/train/images",   # Path to training images
                "masks": "my_data_dir/train/masks",     # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",     # Path to validation images
                "masks": "my_data_dir/val/masks",       # Path to validation masks
            },
            "classes": {                                # Classes in the dataset                    
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0], 
        },
    )

The classes in the dataset must be specified in the classes dictionary, where the keys are the class IDs and the values are the class names. The class IDs must match the pixel values in the mask images. All possible class IDs must be specified; otherwise, LightlyTrain raises an error when it encounters an unknown class ID. If you would like to ignore some classes during training, specify their class IDs in the ignore_classes argument. The trained model will then not predict these classes.
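
Since training fails on unknown class IDs, it can help to verify your masks against the classes dictionary beforehand. A minimal sketch using Pillow and NumPy, with the paths and classes from the example above:

from pathlib import Path

import numpy as np
from PIL import Image

classes = {0: "background", 1: "car", 2: "bicycle"}
for mask_path in Path("my_data_dir/train/masks").glob("*.png"):
    mask_values = set(np.unique(np.array(Image.open(mask_path))).tolist())
    unknown = mask_values - set(classes)
    if unknown:
        print(f"{mask_path}: unknown class IDs {sorted(unknown)}")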

Model

The model argument defines the model used for semantic segmentation training. The following models are available:

  • dinov2/vits14-eomt

  • dinov2/vitb14-eomt

  • dinov2/vitl14-eomt

  • dinov2/vitg14-eomt

All DINOv2 models are pretrained by Meta.
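
The variants trade speed for accuracy: vits14 is the smallest and fastest backbone, while vitg14 is the largest and typically the most accurate. Switching variants only requires changing the model name, for example:

import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vits14-eomt",  # Smaller and faster than vitl14-eomt
        data={
            # ...
        },
    )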

Logging

Logging is configured with the logger_args argument. The following loggers are supported:

  • mlflow: Logs training metrics to MLflow (disabled by default, requires MLflow to be installed)

  • tensorboard: Logs training metrics to TensorBoard (enabled by default, requires TensorBoard to be installed)

MLflow

Important

MLflow must be installed with pip install "lightly-train[mlflow]".

The mlflow logger can be configured with the following arguments:

import lightly_train

if __name__ == "__main__":
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        data={
            # ...
        },
        logger_args={
            "mlflow": {
                "experiment_name": "my_experiment",
                "run_name": "my_run",
                "tracking_uri": "tracking_uri",
            },
        },
    )

TensorBoard

TensorBoard logs are automatically saved to the output directory. Run TensorBoard in a new terminal to visualize the training progress:

tensorboard --logdir out/my_experiment

Disable the TensorBoard logger with:

logger_args={"tensorboard": None}

Pretrain and Fine-tune a Semantic Segmentation Model

To further improve the performance of your semantic segmentation model, you can first pretrain a DINOv2 model on unlabeled data using self-supervised learning and then fine-tune it on your segmentation dataset. This is especially useful if your dataset is only partially labeled or if you have access to a large amount of unlabeled data.

The following example shows how to pretrain and fine-tune the model. Check out the page on DINOv2 to learn more about pretraining DINOv2 models on unlabeled data.

import lightly_train

if __name__ == "__main__":
    # Pretrain a DINOv2 model.
    lightly_train.train(
        out="out/my_pretrain_experiment",
        data="my_pretrain_data_dir",
        model="dinov2/vitl14",
        method="dinov2",
        epochs=100,  # We recommend epochs = 125000 * batch size // dataset size
    )

    # Fine-tune the DINOv2 model for semantic segmentation.
    lightly_train.train_semantic_segmentation(
        out="out/my_experiment",
        model="dinov2/vitl14-eomt",
        model_args={
            # Path to your pretrained DINOv2 model.
            "backbone_weights": "out/my_pretrain_experiment/exported_models/exported_last.pt",
        },
        data={
            "train": {
                "images": "my_data_dir/train/images",   # Path to training images
                "masks": "my_data_dir/train/masks",     # Path to training masks
            },
            "val": {
                "images": "my_data_dir/val/images",     # Path to validation images
                "masks": "my_data_dir/val/masks",       # Path to validation masks
            },
            "classes": {                                # Classes in the dataset                    
                0: "background",
                1: "car",
                2: "bicycle",
                # ...
            },
            # Optional, classes that are in the dataset but should be ignored during
            # training.
            "ignore_classes": [0], 
        },
    )