.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/package/tutorial_simclr_clothing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_package_tutorial_simclr_clothing.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_package_tutorial_simclr_clothing.py:


.. _lightly-simclr-tutorial-3:

Tutorial 3: Train SimCLR on Clothing
====================================

In this tutorial, we will train a SimCLR model using lightly. The model,
augmentations and training procedure are from
`A Simple Framework for Contrastive Learning of Visual Representations <https://arxiv.org/abs/2002.05709>`_.

The paper explores a rather simple training procedure for contrastive
learning. Since we use the typical NCE-based contrastive learning loss, the
method benefits greatly from larger batch sizes. In this example, we use a
batch size of 256 paired with an input resolution of 128x128 pixels and a
ResNet-18 backbone; this configuration requires about 6GB of GPU memory.

We use the
`clothing dataset from Alex Grigorev <https://github.com/alexeygrigorev/clothing-dataset>`_
for this tutorial.

In this tutorial you will learn:

- How to create a SimCLR model
- How to generate image representations
- How different augmentations impact the learned representations

.. GENERATED FROM PYTHON SOURCE LINES 32-36

Imports
-------

Import the Python frameworks we need for this tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 36-51

.. code-block:: Python

    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torchvision
    from PIL import Image
    from sklearn.neighbors import NearestNeighbors
    from sklearn.preprocessing import normalize

    from lightly.data import LightlyDataset
    from lightly.transforms import SimCLRTransform, utils

.. GENERATED FROM PYTHON SOURCE LINES 52-60

Configuration
-------------

We set some configuration parameters for our experiment.
Feel free to change them and analyze the effect.

The default configuration with a batch size of 256 and input resolution of 128
requires 6GB of GPU memory.

.. GENERATED FROM PYTHON SOURCE LINES 60-67

.. code-block:: Python

    num_workers = 8
    batch_size = 256
    seed = 1
    max_epochs = 20
    input_size = 128
    num_ftrs = 32

.. GENERATED FROM PYTHON SOURCE LINES 68-69

Let's set the seed for our experiments:

.. GENERATED FROM PYTHON SOURCE LINES 69-71

.. code-block:: Python

    pl.seed_everything(seed)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1

.. GENERATED FROM PYTHON SOURCE LINES 72-75

Make sure ``path_to_data`` points to the downloaded clothing dataset.
You can download it using
``git clone https://github.com/alexeygrigorev/clothing-dataset.git``

.. GENERATED FROM PYTHON SOURCE LINES 75-78

.. code-block:: Python

    path_to_data = "/datasets/clothing-dataset/images"

.. GENERATED FROM PYTHON SOURCE LINES 79-91

Setup data augmentations and loaders
------------------------------------

The images from the dataset have been taken from above when the clothing was
on a table, bed or floor. Therefore, we can make use of additional
augmentations such as vertical flip or random rotation (90 degrees). By adding
these augmentations we teach our model invariance to the orientation of the
clothing piece. E.g., we don't care whether a shirt is upside down; we care
about the structure that makes it a shirt.

You can learn more about the different augmentations and learned invariances
here: :ref:`lightly-advanced`.
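To get a feeling for what such a multi-view transform produces before wiring
it into a dataset, you can apply it to a single image. The snippet below is a
small sketch rather than part of the original script: ``my_image.jpg`` is a
placeholder path, and the code reuses the imports and ``input_size`` defined
above. A SimCLR transform returns one augmented tensor per view:

.. code-block:: Python

    # Sketch: inspect the two augmented views produced by a SimCLR transform.
    # "my_image.jpg" is a placeholder; point it at any image on disk.
    sketch_transform = SimCLRTransform(input_size=input_size, vf_prob=0.5, rr_prob=0.5)
    views = sketch_transform(Image.open("my_image.jpg"))
    print(len(views))  # 2 views per image
    print(views[0].shape)  # torch.Size([3, 128, 128])

The actual transforms and dataloaders for this tutorial are set up below.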
.. GENERATED FROM PYTHON SOURCE LINES 91-126

.. code-block:: Python

    transform = SimCLRTransform(input_size=input_size, vf_prob=0.5, rr_prob=0.5)

    # We create a torchvision transformation for embedding the dataset after
    # training
    test_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((input_size, input_size)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=utils.IMAGENET_NORMALIZE["mean"],
                std=utils.IMAGENET_NORMALIZE["std"],
            ),
        ]
    )

    dataset_train_simclr = LightlyDataset(input_dir=path_to_data, transform=transform)

    dataset_test = LightlyDataset(input_dir=path_to_data, transform=test_transform)

    dataloader_train_simclr = torch.utils.data.DataLoader(
        dataset_train_simclr,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )

    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

.. GENERATED FROM PYTHON SOURCE LINES 127-134

Create the SimCLR Model
-----------------------

Now we create the SimCLR model. We implement it as a PyTorch Lightning Module
and use a ResNet-18 backbone from Torchvision. Lightly provides implementations
of the SimCLR projection head and loss function in the ``SimCLRProjectionHead``
and ``NTXentLoss`` classes. We can simply import them and combine the building
blocks in the module.

.. GENERATED FROM PYTHON SOURCE LINES 134-173

.. code-block:: Python

    from lightly.loss import NTXentLoss
    from lightly.models.modules.heads import SimCLRProjectionHead


    class SimCLRModel(pl.LightningModule):
        def __init__(self):
            super().__init__()

            # create a ResNet backbone and remove the classification head
            resnet = torchvision.models.resnet18()
            self.backbone = nn.Sequential(*list(resnet.children())[:-1])

            hidden_dim = resnet.fc.in_features
            self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 128)

            self.criterion = NTXentLoss()

        def forward(self, x):
            h = self.backbone(x).flatten(start_dim=1)
            z = self.projection_head(h)
            return z

        def training_step(self, batch, batch_idx):
            (x0, x1), _, _ = batch
            z0 = self.forward(x0)
            z1 = self.forward(x1)
            loss = self.criterion(z0, z1)
            self.log("train_loss_ssl", loss)
            return loss

        def configure_optimizers(self):
            optim = torch.optim.SGD(
                self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4
            )
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
            return [optim], [scheduler]

.. GENERATED FROM PYTHON SOURCE LINES 174-175

Train the module using the PyTorch Lightning Trainer on a single GPU.

.. GENERATED FROM PYTHON SOURCE LINES 175-180

.. code-block:: Python

    model = SimCLRModel()
    trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="gpu")
    trainer.fit(model, dataloader_train_simclr)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /datasets/actions-runner/core_gpu_runner_01/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /datasets/actions-runner/core_gpu_runner_01/_work/li ...
      rank_zero_warn(
    /datasets/actions-runner/core_gpu_runner_01/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (22) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
      rank_zero_warn(
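With training finished, we can use the backbone to generate image
representations and inspect them, for example with a nearest-neighbour search
over the embeddings. The sketch below shows one way to do this;
``generate_embeddings`` and ``plot_knn_examples`` are illustrative helper
names, not part of the lightly API, and the code relies on ``model``,
``dataloader_test``, and the imports from above.

.. code-block:: Python

    def generate_embeddings(model, dataloader):
        """Generate representations for all images in the dataloader.

        Runs the backbone over the dataset without gradients and
        l2-normalizes the resulting embeddings.
        """
        embeddings = []
        filenames = []
        model.eval()
        with torch.no_grad():
            for img, _, fnames in dataloader:
                img = img.to(model.device)
                emb = model.backbone(img).flatten(start_dim=1)
                embeddings.append(emb)
                filenames.extend(fnames)
        embeddings = torch.cat(embeddings, 0)
        embeddings = normalize(embeddings.cpu().numpy())
        return embeddings, filenames


    def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6):
        """Plot a few random query images next to their nearest neighbours."""
        nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
        _, indices = nbrs.kneighbors(embeddings)
        samples_idx = np.random.choice(len(indices), size=num_examples, replace=False)
        for idx in samples_idx:
            fig = plt.figure()
            for plot_x_offset, neighbor_idx in enumerate(indices[idx]):
                ax = fig.add_subplot(1, len(indices[idx]), plot_x_offset + 1)
                fname = os.path.join(path_to_data, filenames[neighbor_idx])
                plt.imshow(Image.open(fname))
                ax.set_axis_off()


    embeddings, filenames = generate_embeddings(model, dataloader_test)
    plot_knn_examples(embeddings, filenames)

Because the embeddings are l2-normalized, the Euclidean distances used by the
nearest-neighbour search are monotonically related to cosine similarity.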
.. GENERATED FROM PYTHON SOURCE LINES 310-321

Next Steps
------------

Interested in exploring other self-supervised models? Check out our other
tutorials:

- :ref:`lightly-moco-tutorial-2`
- :ref:`lightly-simsiam-tutorial-4`
- :ref:`lightly-custom-augmentation-5`
- :ref:`lightly-detectron-tutorial-6`


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (4 minutes 20.651 seconds)


.. _sphx_glr_download_tutorials_package_tutorial_simclr_clothing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tutorial_simclr_clothing.ipynb <tutorial_simclr_clothing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_simclr_clothing.py <tutorial_simclr_clothing.py>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_