.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/package/tutorial_moco_memory_bank.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_package_tutorial_moco_memory_bank.py:

.. _lightly-moco-tutorial-2:

Tutorial 2: Train MoCo on CIFAR-10
==============================================

In this tutorial, we will train a model based on the MoCo paper
`Momentum Contrast for Unsupervised Visual Representation Learning `_.

When training self-supervised models with a contrastive loss, we usually face
one big problem: the loss needs many negative examples to work well, which in
turn requires a large batch size. However, not everyone has access to a cluster
full of GPUs or TPUs. To solve this problem, alternative approaches have been
developed. Some of them use a memory bank that stores old negative examples,
which we can query to compensate for the smaller batch size. MoCo takes this
approach one step further by including a momentum encoder.

We use the **CIFAR-10** dataset for this tutorial.

In this tutorial you will learn:

- How to use lightly to load a dataset and train a model

- How to create a MoCo model with a memory bank

- How to use the pre-trained model after self-supervised learning for a
  transfer learning task

.. GENERATED FROM PYTHON SOURCE LINES 35-44

Imports
-------

Import the Python frameworks we need for this tutorial.
Make sure you have lightly installed.

.. code-block:: console

  pip install lightly

.. GENERATED FROM PYTHON SOURCE LINES 44-64

.. code-block:: Python


    import copy

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    import torchvision

    from lightly.data import LightlyDataset
    from lightly.loss import NTXentLoss
    from lightly.models import ResNetGenerator
    from lightly.models.modules.heads import MoCoProjectionHead
    from lightly.models.utils import (
        batch_shuffle,
        batch_unshuffle,
        deactivate_requires_grad,
        update_momentum,
    )
    from lightly.transforms import MoCoV2Transform, utils


.. GENERATED FROM PYTHON SOURCE LINES 65-75

Configuration
-------------

We set some configuration parameters for our experiment.
Feel free to change them and analyze the effect.

The default configuration uses a batch size of 512, which requires around
6.4 GB of GPU memory.
When training for 100 epochs you should achieve around 73% test set accuracy.
When training for 200 epochs accuracy increases to about 80%.

.. GENERATED FROM PYTHON SOURCE LINES 75-82

.. code-block:: Python


    num_workers = 8
    batch_size = 512
    memory_bank_size = 4096
    seed = 1
    max_epochs = 100


.. GENERATED FROM PYTHON SOURCE LINES 83-89

Replace the path with the location of your CIFAR-10 dataset.
We assume we have a train folder with subfolders
for each class and .png images inside.

You can download `CIFAR-10 in folders from Kaggle `_.

.. GENERATED FROM PYTHON SOURCE LINES 89-107

.. code-block:: Python


    # The dataset structure should be like this:
    # cifar10/train/
    #  L airplane/
    #    L 10008_airplane.png
    #    L ...
    #  L automobile/
    #  L bird/
    #  L cat/
    #  L deer/
    #  L dog/
    #  L frog/
    #  L horse/
    #  L ship/
    #  L truck/
    path_to_train = "/datasets/cifar10/train/"
    path_to_test = "/datasets/cifar10/test/"


.. GENERATED FROM PYTHON SOURCE LINES 108-109

Let's set the seed to ensure reproducibility of the experiments.

.. GENERATED FROM PYTHON SOURCE LINES 109-112

.. code-block:: Python

    pl.seed_everything(seed)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    1


.. GENERATED FROM PYTHON SOURCE LINES 113-125

Set up data augmentations and loaders
-------------------------------------

We start with our data preprocessing pipeline. We can implement the
augmentations from the MoCo paper using the transforms provided by lightly.
Images from the CIFAR-10 dataset have a resolution of 32x32 pixels. Let's use
this resolution to train our model.

.. note::  We could use a higher input resolution to train our model. However,
  since the original resolution of CIFAR-10 images is low there is no real value
  in increasing the resolution. A higher resolution results in higher memory
  consumption and to compensate for that we would need to reduce the batch size.

.. GENERATED FROM PYTHON SOURCE LINES 125-132

.. code-block:: Python


    # disable blur because we're working with tiny images
    transform = MoCoV2Transform(
        input_size=32,
        gaussian_blur=0.0,
    )


.. GENERATED FROM PYTHON SOURCE LINES 133-137

We don't want any augmentation for our test data. Therefore,
we create custom, torchvision-based data transformations.
Let's ensure the size is correct and we normalize the data in
the same way as we do with the training data.

.. GENERATED FROM PYTHON SOURCE LINES 137-177

.. code-block:: Python


    # Augmentations typically used to train on CIFAR-10
    train_classifier_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=utils.IMAGENET_NORMALIZE["mean"],
                std=utils.IMAGENET_NORMALIZE["std"],
            ),
        ]
    )

    # No additional augmentations for the test set
    test_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=utils.IMAGENET_NORMALIZE["mean"],
                std=utils.IMAGENET_NORMALIZE["std"],
            ),
        ]
    )

    # We use the MoCo augmentations for training MoCo
    dataset_train_moco = LightlyDataset(input_dir=path_to_train, transform=transform)

    # Since we also train a linear classifier on the pre-trained MoCo model, we
    # use the light classifier augmentations defined above here (MoCo
    # augmentations are very strong and usually reduce the accuracy of models
    # that are not trained with contrastive learning. Our linear layer will be
    # trained using cross entropy loss and labels provided by the dataset.
    # Therefore we chose light augmentations.)
    dataset_train_classifier = LightlyDataset(
        input_dir=path_to_train, transform=train_classifier_transforms
    )

    dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms)


.. GENERATED FROM PYTHON SOURCE LINES 178-180

Create the dataloaders to load and preprocess the data
in the background.

.. GENERATED FROM PYTHON SOURCE LINES 180-206

.. code-block:: Python


    dataloader_train_moco = torch.utils.data.DataLoader(
        dataset_train_moco,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )

    dataloader_train_classifier = torch.utils.data.DataLoader(
        dataset_train_classifier,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )

    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )


.. GENERATED FROM PYTHON SOURCE LINES 207-221

Create the MoCo Lightning Module
--------------------------------

Now we create our MoCo model. We use PyTorch Lightning to train our model and
follow the specification of the Lightning module. In this example we set the
number of features for the hidden dimension to 512. The momentum for the
momentum encoder is set to 0.99 (default is 0.999) since other reports show
that this works better for CIFAR-10.

For the backbone we use the lightly variant of a ResNet-18. You can use another
model following our
`playground to use custom backbones `_.

.. note:: We use a split batch norm to simulate multi-GPU behaviour. Combined
  with the use of batch shuffling, this prevents the model from communicating
  through the batch norm layers.

.. GENERATED FROM PYTHON SOURCE LINES 221-285

.. code-block:: Python


    class MocoModel(pl.LightningModule):
        def __init__(self):
            super().__init__()

            # create a ResNet backbone and remove the classification head
            resnet = ResNetGenerator("resnet-18", 1, num_splits=8)
            self.backbone = nn.Sequential(
                *list(resnet.children())[:-1],
                nn.AdaptiveAvgPool2d(1),
            )

            # create a MoCo model based on ResNet
            self.projection_head = MoCoProjectionHead(512, 512, 128)
            self.backbone_momentum = copy.deepcopy(self.backbone)
            self.projection_head_momentum = copy.deepcopy(self.projection_head)
            deactivate_requires_grad(self.backbone_momentum)
            deactivate_requires_grad(self.projection_head_momentum)

            # create our loss with the optional memory bank
            self.criterion = NTXentLoss(
                temperature=0.1, memory_bank_size=(memory_bank_size, 128)
            )

        def training_step(self, batch, batch_idx):
            (x_q, x_k), _, _ = batch

            # update momentum
            update_momentum(self.backbone, self.backbone_momentum, 0.99)
            update_momentum(self.projection_head, self.projection_head_momentum, 0.99)

            # get queries
            q = self.backbone(x_q).flatten(start_dim=1)
            q = self.projection_head(q)

            # get keys
            k, shuffle = batch_shuffle(x_k)
            k = self.backbone_momentum(k).flatten(start_dim=1)
            k = self.projection_head_momentum(k)
            k = batch_unshuffle(k, shuffle)

            loss = self.criterion(q, k)
            self.log("train_loss_ssl", loss)
            return loss

        def on_train_epoch_end(self):
            self.custom_histogram_weights()

        # We provide a helper method to log weights in tensorboard
        # which is useful for debugging.
        def custom_histogram_weights(self):
            for name, params in self.named_parameters():
                self.logger.experiment.add_histogram(name, params, self.current_epoch)

        def configure_optimizers(self):
            optim = torch.optim.SGD(
                self.parameters(),
                lr=6e-2,
                momentum=0.9,
                weight_decay=5e-4,
            )
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
            return [optim], [scheduler]


.. GENERATED FROM PYTHON SOURCE LINES 286-290

Create the Classifier Lightning Module
--------------------------------------

We create a linear classifier on top of the features extracted by the MoCo
backbone and train it on the labeled dataset.

.. GENERATED FROM PYTHON SOURCE LINES 290-358

.. code-block:: Python


    class Classifier(pl.LightningModule):
        def __init__(self, backbone):
            super().__init__()
            # use the pretrained ResNet backbone
            self.backbone = backbone

            # freeze the backbone
            deactivate_requires_grad(backbone)

            # create a linear layer for our downstream classification model
            self.fc = nn.Linear(512, 10)

            self.criterion = nn.CrossEntropyLoss()
            self.validation_step_outputs = []

        def forward(self, x):
            y_hat = self.backbone(x).flatten(start_dim=1)
            y_hat = self.fc(y_hat)
            return y_hat

        def training_step(self, batch, batch_idx):
            x, y, _ = batch
            y_hat = self.forward(x)
            loss = self.criterion(y_hat, y)
            self.log("train_loss_fc", loss)
            return loss

        def on_train_epoch_end(self):
            self.custom_histogram_weights()

        # We provide a helper method to log weights in tensorboard
        # which is useful for debugging.
        def custom_histogram_weights(self):
            for name, params in self.named_parameters():
                self.logger.experiment.add_histogram(name, params, self.current_epoch)

        def validation_step(self, batch, batch_idx):
            x, y, _ = batch
            y_hat = self.forward(x)
            y_hat = torch.nn.functional.softmax(y_hat, dim=1)

            # calculate number of correct predictions
            _, predicted = torch.max(y_hat, 1)
            num = predicted.shape[0]
            correct = (predicted == y).float().sum()
            self.validation_step_outputs.append((num, correct))
            return num, correct

        def on_validation_epoch_end(self):
            # calculate and log top1 accuracy
            if self.validation_step_outputs:
                total_num = 0
                total_correct = 0
                for num, correct in self.validation_step_outputs:
                    total_num += num
                    total_correct += correct
                acc = total_correct / total_num
                self.log("val_acc", acc, on_epoch=True, prog_bar=True)
                self.validation_step_outputs.clear()

        def configure_optimizers(self):
            optim = torch.optim.SGD(self.fc.parameters(), lr=30.0)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
            return [optim], [scheduler]


.. GENERATED FROM PYTHON SOURCE LINES 359-364

Train the MoCo model
--------------------

We can instantiate the model and train it using the
lightning trainer.

.. GENERATED FROM PYTHON SOURCE LINES 364-369

.. code-block:: Python

    model = MocoModel()
    trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="gpu")
    trainer.fit(model, dataloader_train_moco)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /datasets/actions-runner/core_gpu_runner_01/hostedtoolcache/Python/3.10.13/x64/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:165: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /datasets/actions-runner/core_gpu_runner_01/_work/li ...
      rank_zero_warn(

    Training: 0it [00:00, ?it/s]


.. container:: sphx-glr-download sphx-glr-download-python

  :download:`Download Python source code: tutorial_moco_memory_bank.py `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
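
As a supplement to the momentum encoder used in ``MocoModel`` above, the
following plain-Python sketch illustrates the exponential-moving-average rule
that a momentum update typically applies. It is not lightly's implementation:
``ema_update`` is a hypothetical helper that works on lists of floats instead
of model parameters, and it assumes the rule described in the MoCo paper, where
each momentum weight becomes ``m * momentum_weight + (1 - m) * online_weight``.

```python
def ema_update(online, target, m=0.99):
    """Move the target weights a small step toward the online weights.

    Hypothetical helper for illustration only. `online` and `target` are
    equally long lists of floats standing in for model parameters.
    """
    if not 0.0 <= m <= 1.0:
        raise ValueError("m must be in [0, 1]")
    if len(online) != len(target):
        raise ValueError("online and target must have the same length")
    # Keep a fraction m of the old target value and mix in (1 - m)
    # of the current online value.
    return [m * t + (1.0 - m) * o for o, t in zip(online, target)]


# Stand-ins for the query encoder and the momentum (key) encoder.
online_weights = [1.0, 2.0]
target_weights = [0.0, 0.0]

# One update with the momentum value used in this tutorial.
target_weights = ema_update(online_weights, target_weights, m=0.99)
print(target_weights)  # approximately [0.01, 0.02]
```

With ``m = 0.99`` the momentum weights move only about 1% of the way toward the
online weights per step, so the key encoder changes slowly from one batch to
the next.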