Classification with Torchvision’s ResNet

This tutorial demonstrates how to pre-train a ResNet model from Torchvision using LightlyTrain and then fine-tune it for classification using the PyTorch Lightning framework. We will perform both steps on the Human Detection Dataset from Kaggle.

Install LightlyTrain

LightlyTrain can be installed directly from PyPI:

pip install lightly-train
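
To confirm the installation succeeded, you can query the installed package version from Python. The sketch below uses only the standard library; the helper name installed_version is hypothetical and not part of LightlyTrain.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # "lightly-train" is the distribution name used in the pip command above.
    print(installed_version("lightly-train"))
```

If this prints None, the package is not visible to the Python interpreter you are running, which usually means a different environment is active.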

Dataset Preparation

Download the Dataset

The Human Detection Dataset contains 921 PNG images of size 256x256 pixels, taken from videos, some showing humans and some not. It can be used to train models that detect humans in images, a basic task in fields such as security and autonomous driving.

You can download the dataset directly from Kaggle using the following commands (assuming you want the dataset to be located in datasets):

mkdir -p datasets
curl -L -o datasets/human-detection-dataset.zip https://www.kaggle.com/api/v1/datasets/download/constantinwerner/human-detection-dataset

Then extract the zip file into the datasets directory:

unzip datasets/human-detection-dataset.zip -d datasets

The resulting dataset directory contains two classes in its subdirectories: 0 for images without humans and 1 for images with humans.

tree -L 1 datasets/"human detection dataset"
> human detection dataset
> ├── 0
> └── 1
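
As a quick sanity check after extraction, you can count the images in each class directory. This is a minimal sketch using only the standard library; the function name count_images_per_class is hypothetical, and it assumes the directory layout shown above.

```python
from pathlib import Path
from typing import Dict


def count_images_per_class(root: Path, pattern: str = "*.png") -> Dict[str, int]:
    """Map each immediate subdirectory of `root` to its number of matching files."""
    if not root.is_dir():
        return {}
    return {
        class_dir.name: sum(1 for _ in class_dir.glob(pattern))
        for class_dir in sorted(root.iterdir())
        if class_dir.is_dir()
    }


if __name__ == "__main__":
    # Suppose you have the dataset in the datasets/ directory
    dataset_path = Path("datasets") / "human detection dataset"
    print(count_images_per_class(dataset_path))
```

The counts for the two classes should add up to the 921 images mentioned above.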

Split the Dataset

Before we can train the model, we need to split the dataset into training and validation sets. We will use 80% of the images for training and 20% for validation. The following Python script will create the train and val directories and move the images into their respective subdirectories.

# dataset_split.py
import random
from pathlib import Path

# Suppose you have the dataset in the datasets/ directory
dataset_path = Path("datasets") / "human detection dataset"

# Define class names (subdirectories) in the dataset
classes = ['0', '1']

# Create train and val directories with subdirectories for each class
for split in ['train', 'val']:
    for data_class in classes:
        (dataset_path / split / data_class).mkdir(parents=True, exist_ok=True)

# Process each class folder
for data_class in classes:
    class_dir = dataset_path / data_class
    # List all files in the class directory
    files = list(class_dir.glob("*.png"))
    # Shuffle the file list to randomize the split
    random.shuffle(files)
    # Calculate the split index for 80% training data
    split_idx = int(len(files) * 0.8)

    # Select files for training and validation
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    # Move training files to the train subdirectory
    for file_path in train_files:
        dest_path = dataset_path / 'train' / data_class / file_path.name
        file_path.rename(dest_path)

    # Move validation files to the val subdirectory
    for file_path in val_files:
        dest_path = dataset_path / 'val' / data_class / file_path.name
        file_path.rename(dest_path)

    # Remove the now-empty class directory (raises OSError if other files remain)
    class_dir.rmdir()

The resulting dataset directory contains two split subdirectories: train and val, each with two classes in their subdirectories.

tree -L 2 datasets/"human detection dataset"
> human detection dataset
> ├── train
> │   ├── 0
> │   └── 1
> └── val
>     ├── 0
>     └── 1
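
The split script shuffles without a fixed seed, so every run produces a different split. If you need a reproducible split, the index arithmetic can be isolated in a small seeded helper. The sketch below is an illustration only: the name split_items and the seed parameter are assumptions, not part of the original script.

```python
import random
from typing import List, Sequence, Tuple


def split_items(
    items: Sequence[str], train_fraction: float = 0.8, seed: int = 0
) -> Tuple[List[str], List[str]]:
    """Shuffle `items` deterministically and split them into (train, val)."""
    # Sort first so the result does not depend on directory listing order.
    shuffled = sorted(items)
    random.Random(seed).shuffle(shuffled)
    # int() truncates, so any remainder goes to the validation set.
    split_idx = int(len(shuffled) * train_fraction)
    return shuffled[:split_idx], shuffled[split_idx:]


if __name__ == "__main__":
    names = [f"{i}.png" for i in range(10)]
    train, val = split_items(names)
    print(len(train), len(val))
```

Because the split index is truncated, a class with 7 images yields 5 training and 2 validation images rather than an exact 80/20 ratio.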

Inspect a Few Images

Let’s inspect a few images from each class in the training set to understand the dataset better. We will randomly select two images from each class and display them using Matplotlib.

# inspect_images.py
import random
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Suppose you have the dataset in the datasets/ directory
dataset_path = Path("datasets") / "human detection dataset"

# Define paths to the training image directories
train_data_path = dataset_path / "train"
class_0_dir = train_data_path / "0"  # No human
class_1_dir = train_data_path / "1"  # Human

# Function to get two random images (with their file names) from a directory
def get_two_random_images(directory: Path) -> List[Tuple[Image.Image, str]]:
    image_files = list(directory.glob('*.png'))
    selected_files = random.sample(image_files, 2)
    images = []
    
    for file_path in selected_files:
        img = Image.open(file_path)
        images.append((img, file_path.name))
    
    return images

# Get random images from each class
class_0_images = get_two_random_images(class_0_dir)
class_1_images = get_two_random_images(class_1_dir)

# Set up the figure for display
fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# Display images from class 0 (no human)
for i, (img, filename) in enumerate(class_0_images):
    axs[0, i].imshow(np.array(img))
    axs[0, i].set_title(f"No Human: {filename}")
    axs[0, i].axis('off')

# Display images from class 1 (human)
for i, (img, filename) in enumerate(class_1_images):
    axs[1, i].imshow(np.array(img))
    axs[1, i].set_title(f"Human: {filename}")
    axs[1, i].axis('off')

plt.tight_layout()
plt.show()

[Figure: sample images from the Human Detection Dataset]
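
Besides looking at a few samples, it can be useful to confirm that the images really have the stated 256x256 size. The sketch below reads the width and height from the PNG header (the IHDR chunk) using only the standard library; the helper name png_size is hypothetical. Pillow's Image.size attribute gives the same information if you prefer to stay with the library used above.

```python
import struct
from pathlib import Path
from typing import Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path: Path) -> Tuple[int, int]:
    """Return (width, height) read from the IHDR chunk of a PNG file."""
    with path.open("rb") as f:
        header = f.read(24)
    # Layout: 8-byte signature, 4-byte chunk length, 4-byte chunk type "IHDR",
    # then width and height as big-endian 32-bit unsigned integers.
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a valid PNG file")
    width, height = struct.unpack(">II", header[16:24])
    return width, height


if __name__ == "__main__":
    # Suppose you have the split dataset in the datasets/ directory
    train_dir = Path("datasets") / "human detection dataset" / "train"
    sizes = set()
    if train_dir.is_dir():
        sizes = {png_size(p) for p in train_dir.glob("*/*.png")}
    print(sizes)
```

A single entry in the printed set means all training images share the same dimensions.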

Pre-train ResNet with LightlyTrain

We will use LightlyTrain to pre-train a ResNet18 model.

The following Python script, or the CLI command after it, will:

  • Initialize a ResNet18 model from Torchvision weights using LightlyTrain.

  • Pre-train the ResNet18 model on the Human Detection Dataset.

  • Export the pre-trained ResNet18 model.

# pretrain_resnet.py
import lightly_train
from pathlib import Path

# Suppose you have the dataset in the datasets/ directory
dataset_path = Path("datasets") / "human detection dataset"

if __name__ == "__main__":
    lightly_train.train(
        out="out/my_experiment",                # Output directory.
        data=dataset_path / "train",            # Directory with images.
        model="torchvision/resnet18",           # Pass the Torchvision model.
        epochs=100,                             # Adjust epochs for faster training.
        batch_size=64,                          # Adjust batch size based on hardware.
    )

The same pre-training run can be started from the command line:

lightly-train train out="out/my_experiment" data=datasets/"human detection dataset"/train model="torchvision/resnet18" epochs=100 batch_size=64
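
The fine-tuning script in the next section loads the exported model from out/my_experiment/exported_models/exported_last.pt. Before moving on, you may want to verify that pre-training produced this file. This is a minimal sketch; the helper names are hypothetical, and the path layout is taken from the fine-tuning script rather than from a documented guarantee.

```python
from pathlib import Path


def exported_checkpoint_path(out_dir: Path) -> Path:
    """Build the path of the exported model inside a LightlyTrain output directory."""
    return Path(out_dir) / "exported_models" / "exported_last.pt"


def require_checkpoint(out_dir: Path) -> Path:
    """Return the exported model path, or raise if pre-training has not produced it."""
    path = exported_checkpoint_path(out_dir)
    if not path.is_file():
        raise FileNotFoundError(f"No exported model found at {path}")
    return path


if __name__ == "__main__":
    path = exported_checkpoint_path(Path("out/my_experiment"))
    print(f"{path}: {'found' if path.is_file() else 'missing'}")
```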

Fine-tune ResNet with PyTorch Lightning

We will use PyTorch Lightning to fine-tune the ResNet18 model pre-trained with LightlyTrain on the Human Detection Dataset.

The following Python script will:

  • Load the pre-trained ResNet18 model.

  • Define a PyTorch Lightning module and change the last layer to output two classes.

  • Define a PyTorch Lightning data module with training and validation data loaders.

  • Initialize a PyTorch Lightning trainer.

  • Fine-tune the model on the Human Detection Dataset.

# fine_tune_resnet.py
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18
from torchvision.transforms.v2 import Compose, Normalize, Resize, ToDtype, ToImage

# Suppose you have the dataset in the datasets/ directory
dataset_path = Path("datasets") / "human detection dataset"

def get_model(
    checkpoint_path: str,
    num_classes: int,
):
    model = resnet18()
    state_dict = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(state_dict)

    # Change the last layer for the number of classes
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model


class ResNet18Classifier(pl.LightningModule):
    def __init__(self, checkpoint_path, num_classes):
        super().__init__()
        self.save_hyperparameters()

        self.model = get_model(checkpoint_path, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

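    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # Log the training loss so it appears in the logger output
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        # Log the validation loss and show it in the progress bar
        self.log("val_loss", loss, prog_bar=True)

        return loss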

    def configure_optimizers(self):
        optimizer = Adam(self.parameters())
        return optimizer


class HumanClassificationDataModule(pl.LightningDataModule):
    def __init__(self, train_data_path, val_data_path, batch_size=32):
        super().__init__()
        self.train_data_path = train_data_path
        self.val_data_path = val_data_path
        self.batch_size = batch_size

        # Define transforms
        self.train_transform = Compose(
            [
                Resize((224, 224)),  # Standard ImageNet input size for ResNet18
                ToImage(),
                ToDtype(torch.float32, scale=True),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        self.val_transform = Compose(
            [
                Resize((224, 224)),  # Standard ImageNet input size for ResNet18
                ToImage(),
                ToDtype(torch.float32, scale=True),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def train_dataloader(self):
        train_dataset = ImageFolder(root=self.train_data_path, transform=self.train_transform)
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        val_dataset = ImageFolder(root=self.val_data_path, transform=self.val_transform)
        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

if __name__ == "__main__":
    # Initialize model and data module
    model = ResNet18Classifier(
        checkpoint_path="out/my_experiment/exported_models/exported_last.pt",
        num_classes=2,
    )

    data_module = HumanClassificationDataModule(
        train_data_path=dataset_path / "train",
        val_data_path=dataset_path / "val",
        batch_size=32
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=10,
        log_every_n_steps=16,
    )

    # Fine-tune the model
    trainer.fit(model, data_module)
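
To make the transforms in the data module more concrete, the sketch below reproduces their per-pixel arithmetic in plain Python: ToDtype(torch.float32, scale=True) maps 8-bit values to the range [0, 1], and Normalize then subtracts the per-channel mean and divides by the per-channel standard deviation (the values used above are the standard ImageNet statistics). The function normalize_pixel is hypothetical and exists only for illustration; it is not how Torchvision implements these transforms.

```python
from typing import List, Sequence

# Per-channel statistics used in the Normalize transform above
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(
    rgb: Sequence[int],
    mean: Sequence[float] = IMAGENET_MEAN,
    std: Sequence[float] = IMAGENET_STD,
) -> List[float]:
    """Apply the scale-then-normalize arithmetic to a single 8-bit RGB pixel."""
    # Step 1: scale 0..255 integers to 0..1 floats
    scaled = [channel / 255.0 for channel in rgb]
    # Step 2: subtract the mean and divide by the standard deviation per channel
    return [(value - m) / s for value, m, s in zip(scaled, mean, std)]


if __name__ == "__main__":
    print(normalize_pixel((0, 0, 0)))        # black pixel
    print(normalize_pixel((255, 255, 255)))  # white pixel
```

After normalization, pixel values are roughly centered around zero, with black pixels mapping to negative values and white pixels to positive ones.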

Congratulations! You have successfully pre-trained a model using LightlyTrain and fine-tuned it for classification using PyTorch Lightning.

For more advanced options, explore the LightlyTrain Python API and PyTorch Lightning documentation.

Next Steps

  • Go beyond the default distillation pre-training and experiment with other pre-training methods in LightlyTrain, such as DINO or SimCLR.

  • Try various Torchvision models supported by LightlyTrain.

  • Use the pre-trained model for other tasks, like image embeddings.