Monocular Depth Estimation with fastai U-Net (Advanced)¶
Important
This tutorial requires substantial computational resources. We recommend at least 4 x RTX-4090 GPUs (or comparable) and approximately 3-4 days of training time.
This advanced tutorial demonstrates how to pretrain and fine-tune a U-Net from fast.ai for monocular depth estimation while exploring the customization capabilities of LightlyTrain. We will pretrain two ResNet-50 encoders with different augmentation settings to analyze their impact on model performance.
To begin, install the required dependencies:
pip install lightly-train fastai
The tutorial consists of three main steps:
Dataset acquisition and preprocessing for pretraining and fine-tuning
Pretraining of two U-Net encoders using LightlyTrain with distinct augmentation configurations
Fine-tuning and performance comparison of both networks
Data Downloading and Processing¶
For this implementation, we utilize two complementary datasets: MegaDepth for pretraining and DIODE for fine-tuning. MegaDepth provides a comprehensive collection of outdoor scenes with synthetic depth maps derived from structure-from-motion reconstruction. While the synthetic depth maps aren’t used during pre-training, the dataset’s extensive outdoor scene distribution aligns well with our target domain. DIODE complements this with high-precision LiDAR-scanned ground-truth depth maps, ensuring accurate supervision during fine-tuning.
To obtain the MegaDepth dataset, run the following command (approximately 200GB):
Note
Due to the lengthy download process, we recommend using a terminal multiplexer such as tmux.
wget https://www.cs.cornell.edu/projects/megadepth/dataset/Megadepth_v1/MegaDepth_v1.tar.gz
For the DIODE dataset, download both training and validation splits (approximately 110GB combined):
wget http://diode-dataset.s3.amazonaws.com/train.tar.gz
wget http://diode-dataset.s3.amazonaws.com/val.tar.gz
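Once the downloads have finished, extract the archives. Below is a minimal sketch using Python's standard library; the archive names match the wget commands above, while the target directories are assumptions that you should adapt so the extracted folders match the dataset roots used later in this tutorial.

# extract.py - unpack the downloaded archives (target paths are assumptions).
import tarfile

with tarfile.open("MegaDepth_v1.tar.gz", "r:gz") as tar:
    tar.extractall(path="/datasets")  # the tutorial below assumes the data ends up under /datasets/MegaDepthv1
with tarfile.open("train.tar.gz", "r:gz") as tar:
    tar.extractall(path="/datasets/DIODE")  # DIODE training split
with tarfile.open("val.tar.gz", "r:gz") as tar:
    tar.extractall(path="/datasets/DIODE")  # DIODE validation split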
To inspect the characteristics of both datasets, we can visualize representative samples:
import glob
from pathlib import Path
from random import shuffle

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

root = "/datasets/MegaDepthv1"  # replace with your dataset root
root = Path(root)
imgs = glob.glob(f"{str(root)}/**/*.jpg", recursive=True)
shuffle(imgs)
print(f"Total images: {len(imgs)}")  # MegaDepth contains 128228 images

imgs = [np.array(Image.open(img)) for img in imgs[:10]]

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for i, img in enumerate(imgs):
    ax = axs[i // 5, i % 5]
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(f"img {i}")
plt.tight_layout()
plt.show()
import glob
from pathlib import Path
from random import shuffle

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

root = "/datasets/DIODE/train/outdoor"  # replace this with your dataset root
root = Path(root)
imgs = glob.glob(f"{str(root)}/**/*.png", recursive=True)
shuffle(imgs)
print(f"Total Outdoors Train Images: {len(imgs)}")  # DIODE has 16884 outdoor images

imgs = imgs[:5]
corr_depths = [np.load(elem.replace("image", "depth").replace(".png", "_depth.npy")) for elem in imgs]
imgs = [np.array(Image.open(img)) for img in imgs]

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for i, img in enumerate(imgs):
    ax = axs[0, i]
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(f"img {i}")
for i, depth in enumerate(corr_depths):
    ax = axs[1, i]
    # Depth maps are stored as (H, W, 1), so squeeze the channel dimension for plotting.
    ax.imshow(depth.squeeze(), cmap="viridis")
    ax.axis("off")
    ax.set_title(f"depth {i}")
plt.tight_layout()
plt.show()
Pretraining¶
The key to effective pretraining lies in the augmentation strategy. Looking at the MegaDepth dataset, we observe a consistent spatial hierarchy - objects at the top of images are typically further away than those at the bottom (consider the sky-to-ground relationship). This spatial consistency means we should avoid training a rotation-invariant model, unlike in scenarios with satellite or aerial imagery where rotation-invariance would be desirable.
To empirically demonstrate the impact of the augmentation choices, we’ll train two encoders:
One with aggressive rotations (90°) and vertical flips.
One with conservative tilts (15°) and no vertical flips.
Those parameters can be adjusted with lightly_train.train's transform_args argument, which expects a dictionary of augmentation parameters.
# pretrain.py
import lightly_train

# Set this to False to turn on the aggressive rotations.
ROTATION_OFF = True


def get_transform_args(rotation_off: bool) -> dict:
    if rotation_off:
        transform_args = {
            "random_flip": {
                "vertical_prob": 0.0,
                "horizontal_prob": 0.5,
            },
            "random_rotation": {
                "prob": 1.0,
                "degrees": 15,
            },
        }
    else:
        transform_args = {
            "random_flip": {
                "vertical_prob": 0.5,
                "horizontal_prob": 0.5,
            },
            "random_rotation": {
                "prob": 1.0,
                "degrees": 90,
            },
        }
    return transform_args


if __name__ == "__main__":
    lightly_train.train(
        out=f"pretrain_logs/megadepth_rotationOff{ROTATION_OFF}",
        data="/datasets/MegaDepthv1",
        model="torchvision/resnet50",
        epochs=500,
        transform_args=get_transform_args(ROTATION_OFF),
    )
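Run the pretraining script twice, once with ROTATION_OFF = True and once with ROTATION_OFF = False, so that each configuration writes its logs and checkpoint to its own directory under pretrain_logs/. These two checkpoints are the encoders we will compare after fine-tuning.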
Fine-tuning¶
For fine-tuning, we implement a custom depth estimation pipeline using PyTorch Lightning. While fast.ai provides excellent high-level abstractions for many downstream tasks, depth estimation is not available out of the box. Let's start by implementing our model, which inherits from LightningModule.
# model.py
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.nn import Module
from torch.optim import Adam


class DepthUnet(LightningModule):
    def __init__(self, unet: Module, learning_rate: float = 1e-4):
        super().__init__()
        self.unet = unet
        # The U-Net weights are saved in the checkpoint anyway, so only the
        # scalar hyperparameters (here the learning rate) are stored in hparams.
        self.save_hyperparameters(ignore="unet")

    def training_step(self, batch, batch_idx):
        out = self(batch["image"])
        loss = F.mse_loss(out, batch["depth"])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        out = self(batch["image"])
        loss = F.mse_loss(out, batch["depth"])
        self.log("val_loss", loss)
        return loss

    def forward(self, x):
        x = self.unet(x)
        return x

    def configure_optimizers(self):
        optim = Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optim
Our model implementation uses MSE loss, which, while simple, is effective for depth estimation when combined with proper normalization. As you can see, each batch is expected to arrive in the model as a dictionary with "image" and "depth" keys, for which we will implement a custom dataset in the next step.
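As an optional refinement, you may want to exclude pixels without valid ground truth from the loss: DIODE ships validity masks (*_depth_mask.npy) alongside the depth maps, and invalid pixels typically carry a depth of zero. The following is a minimal sketch of such a masked loss; the helper name masked_mse_loss is our own and not part of the tutorial code, and it could replace F.mse_loss in training_step and validation_step without further changes.

# A hypothetical masked variant of the MSE loss: only pixels with a valid
# (positive) ground-truth depth contribute to the gradient.
import torch
import torch.nn.functional as F


def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    valid = target > eps  # assumes invalid pixels are stored with zero depth
    if valid.sum() == 0:
        # No valid pixels in this batch: return a zero loss that keeps the graph intact.
        return (pred * 0.0).sum()
    return F.mse_loss(pred[valid], target[valid])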
# datasets.py
import glob
from typing import Callable

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class DIODEDepthDataset(Dataset):
    def __init__(
        self,
        data_dir: str,
        split: str,
        transform: Callable | None = None,
        outdoor_only: bool = False,
        indoor_only: bool = False,
    ):
        if outdoor_only and indoor_only:
            raise ValueError("Cannot specify both outdoor_only and indoor_only")
        if split not in ["train", "val"]:
            raise ValueError("split must be 'train' or 'val'")
        if outdoor_only:
            self.imgs = sorted(glob.glob(f"{data_dir}/{split}/outdoor/**/*.png", recursive=True))
            self.depths = sorted(glob.glob(f"{data_dir}/{split}/outdoor/**/*_depth.npy", recursive=True))
        elif indoor_only:
            self.imgs = sorted(glob.glob(f"{data_dir}/{split}/indoors/**/*.png", recursive=True))
            self.depths = sorted(glob.glob(f"{data_dir}/{split}/indoors/**/*_depth.npy", recursive=True))
        else:
            self.imgs = sorted(glob.glob(f"{data_dir}/{split}/**/*.png", recursive=True))
            self.depths = sorted(glob.glob(f"{data_dir}/{split}/**/*_depth.npy", recursive=True))
        self.transform = transform
        assert len(self.imgs) == len(self.depths), "Mismatch in number of images and depth maps"

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        depth = self.depths[idx]
        img = np.array(Image.open(img).convert("RGB"))
        depth = np.load(depth)
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        depth = torch.from_numpy(depth).permute(2, 0, 1).float()
        if self.transform:
            # Concatenate image and depth so spatial augmentations are applied
            # consistently to both, then split them apart again.
            img = self.transform(torch.cat([img, depth], dim=0))
            _img = img[:3]
            depth = img[3:4]
            img = _img
        return {"image": img, "depth": depth}
We focus on outdoor scenes to maintain domain alignment with our MegaDepth pretraining. For the augmentation pipeline we will stay conservative, only allowing slight rotational corrections (±15°) to account for camera tilt while preserving the crucial vertical spatial relationships in depth estimation. This will also make any performance differences between the encoders attributable to the different pretraining strategies.
With this we can finalize our fine-tuning script (make sure CKPT_PATH points to one of your pretrained checkpoints, or set it to None to fine-tune from scratch):
# finetune.py
import torch
import torchvision.transforms as T
from datasets import DIODEDepthDataset
from fastai.vision.models.unet import DynamicUnet
from model import DepthUnet
from pytorch_lightning import Trainer
from torch.nn import Module, Sequential
from torch.utils.data import DataLoader
from torchvision import models

# Change this to point to your LightlyTrain pretrained model,
# or set it to None to fine-tune from scratch.
CKPT_PATH = "<path-to-pretrained-model>"


def get_train_transform():
    return T.Compose([
        T.RandomRotation(degrees=15),
        T.RandomResizedCrop(size=(768, 768), scale=(0.2, 0.9)),
        T.RandomHorizontalFlip(),
    ])


def get_val_transform():
    return T.Compose([
        T.Resize(size=(768, 768)),
    ])


def init_model(ckpt_path: str | None):
    encoder = models.resnet50()
    if ckpt_path is not None:
        state_dict = torch.load(ckpt_path)
        # Make sure that at least some keys match before loading.
        assert any(k in state_dict.keys() for k in encoder.state_dict().keys()), "No matching keys found in the checkpoint"
        encoder.load_state_dict(state_dict, strict=False)
    unet = DynamicUnet(Sequential(*list(encoder.children())[:-2]), n_out=1, img_size=(768, 768))
    unet.train()
    return unet


def finetune_unet(
    unet: Module,
    data_dir: str,
    batch_size: int = 16,
    learning_rate: float = 1e-4,
    max_epochs: int = 10,
    num_workers: int = 4,
):
    train_transform = get_train_transform()
    val_transform = get_val_transform()
    train_dataset = DIODEDepthDataset(data_dir, "train", transform=train_transform, outdoor_only=True)
    val_dataset = DIODEDepthDataset(data_dir, "val", transform=val_transform, outdoor_only=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    model = DepthUnet(unet, learning_rate=learning_rate)
    trainer = Trainer(
        max_epochs=max_epochs,
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    unet = init_model(CKPT_PATH)
    finetune_unet(
        unet=unet,
        data_dir="/datasets/DIODE",
        batch_size=32,
        learning_rate=1e-4,
        max_epochs=50,
        num_workers=4,
    )
    print("Training completed! 🥳")
To compare the performance of the two pretrained backbones, launch TensorBoard and inspect the fine-tuning runs in your browser.
tensorboard --logdir=lightning_logs
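As a final step, you may want to visually inspect predictions from a fine-tuned model. The sketch below is assumption-heavy: the checkpoint and image paths are placeholders to adapt, while init_model and DepthUnet are the helpers defined earlier in this tutorial.

# predict.py - hypothetical inference sketch; checkpoint and image paths are placeholders.
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T
from PIL import Image

from finetune import init_model
from model import DepthUnet

CKPT = "lightning_logs/version_0/checkpoints/<your-checkpoint>.ckpt"  # placeholder

unet = init_model(ckpt_path=None)  # build the architecture; weights come from the Lightning checkpoint
model = DepthUnet.load_from_checkpoint(CKPT, unet=unet, map_location="cpu")
model.eval()

img = Image.open("<your-image>.png").convert("RGB")  # placeholder
x = T.Compose([T.ToTensor(), T.Resize((768, 768))])(img).unsqueeze(0)
with torch.no_grad():
    depth = model(x)[0, 0]  # predicted depth map of shape (768, 768)

fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].imshow(img)
axs[0].set_title("image")
axs[1].imshow(depth.numpy(), cmap="viridis")
axs[1].set_title("predicted depth")
for ax in axs:
    ax.axis("off")
plt.show()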