.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/package/tutorial_checkpoint_finetuning.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_package_tutorial_checkpoint_finetuning.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_package_tutorial_checkpoint_finetuning.py:

.. _lightly-checkpoint-finetuning-tutorial-7:

Tutorial 7: Finetuning Lightly Checkpoints
===========================================

LightlySSL provides models pre-trained on various datasets such as ImageNet1k,
ImageNet100, Imagenette, and CIFAR-10. The weights of all these models, along with
their hyperparameter configurations, are available at :ref:`lightly-benchmarks`.

In this tutorial, we will learn how to use these pre-trained model checkpoints to
fine-tune an image classification model on the Food-101 dataset using PyTorch Lightning.

.. GENERATED FROM PYTHON SOURCE LINES 17-26

Imports
-------

Import the Python frameworks we need for this tutorial.
Make sure you have the necessary packages installed.

.. code-block:: console

    pip install lightly torchmetrics

.. GENERATED FROM PYTHON SOURCE LINES 26-36

.. code-block:: Python

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn
    from pytorch_lightning.loggers import TensorBoardLogger
    from torchvision import transforms

    from lightly.transforms.utils import IMAGENET_NORMALIZE

.. GENERATED FROM PYTHON SOURCE LINES 37-44

Downloading the Model Checkpoint
--------------------------------

Let's use the ResNet-50 model pre-trained on ImageNet1k with the
`SimCLR <https://arxiv.org/abs/2002.05709>`_ method.
You can browse other model checkpoints at :ref:`lightly-benchmarks`.

.. GENERATED FROM PYTHON SOURCE LINES 44-47

.. code-block:: Python

    checkpoint_url = "https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt"

.. GENERATED FROM PYTHON SOURCE LINES 48-51

TensorBoard Logger
------------------

.. GENERATED FROM PYTHON SOURCE LINES 51-55

.. code-block:: Python

    tb_logger = TensorBoardLogger("tb_logs", name="lightly_finetuning")

.. GENERATED FROM PYTHON SOURCE LINES 56-64

Configuration
-------------

Let's set the configuration parameters for our experiments.
We use a batch size of 32 and an input size of 128.
We only train for 5 epochs because the focus of this tutorial is on
finetuning Lightly checkpoints.

.. GENERATED FROM PYTHON SOURCE LINES 64-76

.. code-block:: Python

    learning_rate = 0.001
    num_workers = 8
    batch_size = 32
    input_size = 128
    seed = 42
    num_train_epochs = 5

    # use cuda if possible
    device = "cuda" if torch.cuda.is_available() else "cpu"
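Loading the Pre-trained Backbone
--------------------------------

Before setting up the data, we can fetch the checkpoint and restore the encoder
weights into a torchvision ResNet-50. The snippet below is a minimal sketch rather
than code from the Lightly benchmarks: it assumes the file is a PyTorch Lightning
checkpoint whose ``state_dict`` stores the encoder weights under a ``backbone.``
prefix, and it reuses ``checkpoint_url``, ``device``, and ``nn`` from the cells above.

.. code-block:: Python

    from torchvision.models import resnet50

    # Download the Lightning checkpoint and pull out its state dict.
    # ASSUMPTION: the encoder weights are stored under a "backbone." prefix.
    checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=device)
    state_dict = checkpoint["state_dict"]

    # Keep only the backbone weights and strip the "backbone." prefix so the
    # keys match a plain torchvision ResNet-50.
    backbone_state_dict = {
        key.replace("backbone.", ""): value
        for key, value in state_dict.items()
        if key.startswith("backbone.")
    }

    backbone = resnet50()
    backbone.fc = nn.Identity()  # drop the ImageNet classification head
    missing, unexpected = backbone.load_state_dict(backbone_state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

If the printed counts are large, the key layout of the checkpoint differs from the
assumption above and the prefix filtering needs to be adjusted accordingly.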
.. GENERATED FROM PYTHON SOURCE LINES 77-90

Setup data augmentations and loaders
-------------------------------------

For this tutorial, we'll use the Food-101 dataset of 101 food categories with
101,000 images. For each class, 250 manually reviewed test images are provided
as well as 750 training images. On purpose, the training images were not cleaned,
and thus still contain some amount of noise. This comes mostly in the form of
intense colors and sometimes wrong labels. All images were rescaled to have a
maximum side length of 512 pixels.

We will also use some minimal augmentations for the train and test subsets.
To learn more about data pipelines in LightlySSL, refer to
:ref:`input-structure-label`; to learn more about the different augmentations
and learned invariances, refer to :ref:`lightly-advanced`.

.. GENERATED FROM PYTHON SOURCE LINES 90-144

.. code-block:: Python

    from torch.utils.data import DataLoader
    from torchvision.datasets import Food101

    num_classes = 101

    # Training Transformations
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_NORMALIZE["mean"],
                std=IMAGENET_NORMALIZE["std"],
            ),
        ]
    )
    train_dataset = Food101(
        "datasets/food101", split="train", download=True, transform=train_transform
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    )

    # Validation Transformations
    val_transform = transforms.Compose(
        [
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_NORMALIZE["mean"],
                std=IMAGENET_NORMALIZE["std"],
            ),
        ]
    )
    val_dataset = Food101(
        "datasets/food101", split="test", download=True, transform=val_transform
    )
    # Do not shuffle or drop samples during validation so that metrics are
    # computed on the full, fixed test split.
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to datasets/food101/food-101.tar.gz

With the dataset and dataloaders in place, the pre-trained backbone can be
fine-tuned on Food-101; a minimal sketch of this step is given at the end of
this page.

.. _sphx_glr_download_tutorials_package_tutorial_checkpoint_finetuning.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tutorial_checkpoint_finetuning.py <tutorial_checkpoint_finetuning.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tutorial_checkpoint_finetuning.zip <tutorial_checkpoint_finetuning.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
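Fine-tuning Sketch
------------------

To fine-tune the checkpoint on Food-101 with PyTorch Lightning, one option is to wrap
the backbone in a ``LightningModule`` with a new linear classification head. The code
below is only a sketch and not the configuration used for the Lightly benchmarks: it
reuses the ``backbone`` restored above together with ``num_classes``, ``learning_rate``,
``seed``, ``num_train_epochs``, ``device``, ``tb_logger``, and the dataloaders, while the
class name ``FinetuneClassifier``, the plain linear head, and the SGD optimizer are
illustrative choices.

.. code-block:: Python

    import torchmetrics


    class FinetuneClassifier(pl.LightningModule):
        def __init__(self, backbone, num_classes, learning_rate):
            super().__init__()
            self.backbone = backbone
            # A ResNet-50 backbone yields 2048-dimensional pooled features.
            self.head = nn.Linear(2048, num_classes)
            self.criterion = nn.CrossEntropyLoss()
            self.learning_rate = learning_rate
            self.val_accuracy = torchmetrics.Accuracy(
                task="multiclass", num_classes=num_classes
            )

        def forward(self, images):
            return self.head(self.backbone(images))

        def training_step(self, batch, batch_idx):
            images, targets = batch
            loss = self.criterion(self(images), targets)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            images, targets = batch
            logits = self(images)
            self.log("val_loss", self.criterion(logits, targets))
            self.log("val_acc", self.val_accuracy(logits, targets))

        def configure_optimizers(self):
            return torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )


    pl.seed_everything(seed)
    model = FinetuneClassifier(backbone, num_classes, learning_rate)
    trainer = pl.Trainer(
        max_epochs=num_train_epochs,
        accelerator="gpu" if device == "cuda" else "cpu",
        logger=tb_logger,
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

Freezing the backbone (setting ``requires_grad = False`` on its parameters) and
training only the head is a cheaper alternative when compute is limited.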