.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/package/tutorial_custom_augmentations.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_package_tutorial_custom_augmentations.py:

.. _lightly-custom-augmentation-5:

Tutorial 5: Custom Augmentations
==============================================

In this tutorial, we will train a model on chest X-ray images in a self-supervised
manner.

In self-supervised learning, X-ray images can pose some problems: They are often more
than eight bits deep, which makes them incompatible with certain standard torchvision
transforms such as random resized cropping. Additionally, some augmentations that are
commonly used in self-supervised learning are ineffective on X-ray images. For example,
applying color jitter to an X-ray image with a single color channel does not make sense.

We will show how to address these problems and how to train a ResNet-18 with MoCo on a
set of 16-bit X-ray images in TIFF format.

The original dataset this tutorial is based on can be found `on Kaggle `_.
These images are in the DICOM format. For simplicity and efficiency, we randomly
selected ~4000 images from the above dataset, resized them such that the larger of
width and height of each image is no larger than 512 pixels, and converted them to the
16-bit TIFF format. To do so, we used ImageMagick, which is preinstalled on most Linux
systems.

.. code::

    mogrify -path path/to/new/dataset -resize 512x512 -format tiff "*.dicom"

.. GENERATED FROM PYTHON SOURCE LINES 30-33

.. code-block:: Python

    import copy

.. GENERATED FROM PYTHON SOURCE LINES 34-38

Imports
-------

Import the Python frameworks we need for this tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 38-62

.. code-block:: Python

    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas
    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torchvision
    from PIL import Image
    from sklearn.neighbors import NearestNeighbors
    from sklearn.preprocessing import normalize

    from lightly.data import LightlyDataset
    from lightly.loss import NTXentLoss
    from lightly.models.modules.heads import MoCoProjectionHead
    from lightly.models.utils import (
        batch_shuffle,
        batch_unshuffle,
        deactivate_requires_grad,
        update_momentum,
    )
    from lightly.transforms.multi_view_transform import MultiViewTransform

.. GENERATED FROM PYTHON SOURCE LINES 63-70

Configuration
-------------

Let's set the configuration parameters for our experiments.
We will use eight workers to fetch the data from disk and a batch size of 128.
The input size of the images is set to 128. With these settings, the training
requires 2.5 GB of GPU memory.

.. GENERATED FROM PYTHON SOURCE LINES 70-77

.. code-block:: Python

    num_workers = 8
    batch_size = 128
    input_size = 128
    seed = 1
    max_epochs = 50

.. GENERATED FROM PYTHON SOURCE LINES 78-79

Let's set the seed for our experiments.

.. GENERATED FROM PYTHON SOURCE LINES 79-82

.. code-block:: Python

    pl.seed_everything(seed)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1

.. GENERATED FROM PYTHON SOURCE LINES 83-84

Set the path to our dataset.

.. GENERATED FROM PYTHON SOURCE LINES 84-87

.. code-block:: Python

    path_to_data = "/datasets/vinbigdata/train_small"
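As a quick sanity check, we can verify that the converted images really are 16 bits
deep before training. The following is a minimal sketch, not part of the original
tutorial; it assumes that the dataset directory contains only TIFF files.

.. code-block:: Python

    # Optional sanity check (assumes path_to_data contains only TIFF files).
    first_file = sorted(os.listdir(path_to_data))[0]
    image = np.array(Image.open(os.path.join(path_to_data, first_file)))
    print(image.dtype)  # uint16 for 16-bit images
    print(image.max())  # values above 255 mean more than 8 bits are used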
.. GENERATED FROM PYTHON SOURCE LINES 88-98

Setup custom data augmentations
-------------------------------

The key to working with 16-bit X-ray images is to convert them to 8-bit images that
are compatible with the torchvision augmentations without creating harmful artifacts.
A good way to do so is to use histogram normalization, as described in
`this paper `_ about Covid-19 prognosis.

Let's write an augmentation that takes as input a numpy array with 16-bit depth and
returns a histogram-normalized 8-bit PIL image.

.. GENERATED FROM PYTHON SOURCE LINES 98-124

.. code-block:: Python

    class HistogramNormalize:
        """Performs histogram normalization on a numpy array and returns an 8-bit PIL image.

        Code was taken and adapted from Facebook:
        https://github.com/facebookresearch/CovidPrognosis

        """

        def __init__(self, number_bins: int = 256):
            self.number_bins = number_bins

        def __call__(self, image: np.ndarray) -> Image.Image:
            # Get the image histogram.
            image_histogram, bins = np.histogram(
                image.flatten(), self.number_bins, density=True
            )
            cdf = image_histogram.cumsum()  # cumulative distribution function
            cdf = 255 * cdf / cdf[-1]  # normalize to the 8-bit range

            # Use linear interpolation of cdf to find new pixel values.
            image_equalized = np.interp(image.flatten(), bins[:-1], cdf)
            # Cast to uint8 so that the result is a true 8-bit image.
            image_equalized = image_equalized.reshape(image.shape).astype("uint8")
            return Image.fromarray(image_equalized)

.. GENERATED FROM PYTHON SOURCE LINES 125-128

Since we can't use color jitter on X-ray images, let's replace it and add some
Gaussian noise instead. It's easiest to apply this after the image has been converted
to a PyTorch tensor.

.. GENERATED FROM PYTHON SOURCE LINES 128-146

.. code-block:: Python

    class GaussianNoise:
        """Applies random Gaussian noise to a tensor.

        The intensity of the noise depends on the mean of the pixel values.
        See https://arxiv.org/pdf/2101.04909.pdf for more information.

        """

        def __call__(self, sample: torch.Tensor) -> torch.Tensor:
            mu = sample.mean()
            # Sample a random signal-to-noise ratio and scale the noise with it.
            snr = np.random.randint(low=4, high=8)
            sigma = mu / snr
            noise = torch.normal(torch.zeros(sample.shape), sigma)
            return sample + noise

.. GENERATED FROM PYTHON SOURCE LINES 147-157

Now that we have implemented our custom augmentations, we can combine them with
available augmentations from the torchvision library to arrive at the same set of
augmentations as used in the aforementioned paper. Make sure that the first
augmentation is the histogram normalization and that the Gaussian noise is applied
after the image has been converted to a tensor.

Note that we also transform the image from grayscale to RGB by simply repeating the
single color channel three times. The reason for this is that our ResNet expects a
three-channel input. This step can be skipped if a different backbone network is used.

.. GENERATED FROM PYTHON SOURCE LINES 157-175

.. code-block:: Python

    # Compose the custom augmentations with available augmentations.
    view_transform = torchvision.transforms.Compose(
        [
            HistogramNormalize(),
            torchvision.transforms.Grayscale(num_output_channels=3),
            torchvision.transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.0)),
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.RandomVerticalFlip(p=0.5),
            torchvision.transforms.GaussianBlur(21),
            torchvision.transforms.ToTensor(),
            GaussianNoise(),
        ]
    )

    # Create a multiview transform that returns two different augmentations of each image.
    transform = MultiViewTransform(transforms=[view_transform, view_transform])
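To illustrate what the multiview transform returns, here is a minimal sketch that runs
the pipeline on a random 16-bit array as a stand-in for a real X-ray image:
`MultiViewTransform` produces one augmented view per transform in its list.

.. code-block:: Python

    # Random 16-bit data as a stand-in for a real X-ray image.
    fake_image = np.random.randint(0, 2**16, size=(512, 512)).astype(np.uint16)

    views = transform(fake_image)
    print(len(views))      # 2, one view per transform in the list
    print(views[0].shape)  # torch.Size([3, 128, 128])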
.. GENERATED FROM PYTHON SOURCE LINES 176-179

Let's take a look at what our augmentation pipeline does to an image!
We plot the original image on the left and two random augmentations on the right.

.. GENERATED FROM PYTHON SOURCE LINES 179-201

.. code-block:: Python

    example_image_name = "55e8e3db7309febee415515d06418171.tiff"
    example_image_path = os.path.join(path_to_data, example_image_name)
    example_image = np.array(Image.open(example_image_path))

    # The transform returns a 3 x H x W tensor; we only show one color channel.
    augmented_image_1 = view_transform(example_image).numpy()[0]
    augmented_image_2 = view_transform(example_image).numpy()[0]

    fig, axs = plt.subplots(1, 3)

    axs[0].imshow(example_image)
    axs[0].set_axis_off()
    axs[0].set_title("Original Image")

    axs[1].imshow(augmented_image_1)
    axs[1].set_axis_off()

    axs[2].imshow(augmented_image_2)
    axs[2].set_axis_off()

.. image-sg:: /tutorials/package/images/sphx_glr_tutorial_custom_augmentations_001.png
    :alt: Original Image
    :srcset: /tutorials/package/images/sphx_glr_tutorial_custom_augmentations_001.png
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 202-213

Setup dataset and dataloader
------------------------------

We create a dataset that loads the images in the input directory. Since the input
images are 16 bits deep, we need to overwrite the image loader such that it doesn't
convert the images to RGB (and hence to 8-bit) automatically.

.. note:: The `LightlyDataset` uses a torchvision dataset underneath, which in turn
    uses an image loader that converts the input image to an 8-bit RGB image. If a
    16-bit grayscale image is loaded that way, all pixel values above 255 are simply
    clamped. Therefore, we overwrite the default image loader with our custom one.

.. GENERATED FROM PYTHON SOURCE LINES 213-235

.. code-block:: Python

    def tiff_loader(f):
        """Loads a 16-bit tiff image and returns it as a numpy array."""
        with open(f, "rb") as fp:
            image = Image.open(fp)
            return np.array(image)


    # Create the dataset with the custom transform and overwrite the image loader.
    dataset_train = LightlyDataset(input_dir=path_to_data, transform=transform)
    dataset_train.dataset.loader = tiff_loader

    # Setup the dataloader for training.
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )

.. GENERATED FROM PYTHON SOURCE LINES 236-248

Create the MoCo model
-----------------------

Using the building blocks provided by lightly, we can write our MoCo model. We
implement it as a PyTorch Lightning module. For the criterion, we use the NTXentLoss,
which should always be used with MoCo. MoCo also requires a memory bank; we set its
size to 4096, which is approximately the size of the input dataset. The temperature
parameter of the loss is set to 0.1. This smooths the cross-entropy term in the loss
function.

The choice of the optimizer is left to the user. Here, we go with simple stochastic
gradient descent with momentum.

.. GENERATED FROM PYTHON SOURCE LINES 248-306

.. code-block:: Python

    class MoCoModel(pl.LightningModule):
        def __init__(self):
            super().__init__()

            # Create a ResNet backbone and remove the classification head.
            resnet = torchvision.models.resnet18()
            self.backbone = nn.Sequential(
                *list(resnet.children())[:-1],
            )

            # The backbone has output dimension 512, which also defines the size of
            # the hidden dimension. We select 128 for the output dimension.
            self.projection_head = MoCoProjectionHead(512, 512, 128)

            # Add the momentum network.
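            # The momentum network is a slowly trailing copy of the online
            # network: its weights are updated with an exponential moving
            # average instead of gradients (see update_momentum in
            # training_step below), so we disable gradient computation for it.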
            self.backbone_momentum = copy.deepcopy(self.backbone)
            self.projection_head_momentum = copy.deepcopy(self.projection_head)
            deactivate_requires_grad(self.backbone_momentum)
            deactivate_requires_grad(self.projection_head_momentum)

            # Create the loss function with memory bank.
            self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=(4096, 128))

        def training_step(self, batch, batch_idx):
            (x_q, x_k), _, _ = batch

            # Update the momentum networks with an exponential moving average.
            update_momentum(self.backbone, self.backbone_momentum, 0.99)
            update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

            # Get the queries.
            q = self.backbone(x_q).flatten(start_dim=1)
            q = self.projection_head(q)

            # Get the keys.
            k, shuffle = batch_shuffle(x_k)
            k = self.backbone_momentum(k).flatten(start_dim=1)
            k = self.projection_head_momentum(k)
            k = batch_unshuffle(k, shuffle)

            loss = self.criterion(q, k)
            self.log("train_loss_ssl", loss)
            return loss

        def configure_optimizers(self):
            # Use SGD optimizer with momentum and weight decay.
            optim = torch.optim.SGD(
                self.parameters(),
                lr=0.1,
                momentum=0.9,
                weight_decay=1e-4,
            )
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
            return [optim], [scheduler]

.. GENERATED FROM PYTHON SOURCE LINES 307-311

Train MoCo with custom augmentations
-------------------------------------

Training the self-supervised model is now very easy. We can create a new MoCoModel
instance and pass it to the PyTorch Lightning trainer.

.. GENERATED FROM PYTHON SOURCE LINES 311-323

.. code-block:: Python

    model = MoCoModel()

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        devices=1,
        accelerator="gpu",
        precision=16,
    )
    trainer.fit(model, dataloader_train)

.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: tutorial_custom_augmentations.py `

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_