Tutorial 3: Train SimCLR on Clothing

In this tutorial, we will train a SimCLR model using lightly. The model, augmentations and training procedure are taken from A Simple Framework for Contrastive Learning of Visual Representations.

The paper explores a rather simple training procedure for contrastive learning. Since the method uses the typical contrastive learning loss based on NCE, it benefits greatly from larger batch sizes. In this example, we use a batch size of 256; paired with an input resolution of 128x128 pixels per image and a ResNet-18 backbone, this requires about 6GB of GPU memory (see the configuration below).
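To build intuition for why the batch size matters, here is a minimal, self-contained sketch of the NT-Xent loss (not lightly's implementation; lightly provides it as lightly.loss.NTXentLoss, used below): every other sample in the batch acts as a negative for a given positive pair, so a larger batch directly provides more negatives.

import torch
import torch.nn.functional as F

def nt_xent_sketch(z1, z2, temperature=0.5):
    # normalize the 2N projections to unit length: shape (2N, d)
    z = F.normalize(torch.cat([z1, z2]), dim=1)
    # pairwise cosine similarities, scaled by the temperature
    sim = z @ z.t() / temperature
    # mask out self-similarity on the diagonal
    sim.fill_diagonal_(float('-inf'))
    n = z1.shape[0]
    # the positive of sample i is sample i + n (and vice versa);
    # all other rows act as negatives in the softmax below
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return F.cross_entropy(sim, targets)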

We use the clothing dataset from Alex Grigorev for this tutorial.

In this tutorial you will learn:

  • How to create a SimCLR model

  • How different augmentations impact the learned representations

  • How to use the SelfSupervisedEmbedding class from the embedding module to train a model and obtain embeddings

Imports

Import the Python frameworks we need for this tutorial.

import os
import glob
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from PIL import Image
import numpy as np

Configuration

We set some configuration parameters for our experiment. Feel free to change them and analyze the effect.

The default configuration with a batch size of 256 and input resolution of 128 requires 6GB of GPU memory.

num_workers = 8
batch_size = 256
seed = 1
max_epochs = 20
input_size = 128
num_ftrs = 32

Let’s set the seed for our experiments

pl.seed_everything(seed)

Out:

Global seed set to 1

1

Make sure path_to_data points to the downloaded clothing dataset. You can download it using git clone https://github.com/alexeygrigorev/clothing-dataset.git

path_to_data = '/datasets/clothing-dataset/images'

Setup data augmentations and loaders

The images in the dataset were taken from above, with the clothing lying on a table, bed or floor. Therefore, we can make use of additional augmentations such as vertical flips or random (90-degree) rotations. By adding these augmentations we teach our model to be invariant to the orientation of a clothing piece: we don't care whether a shirt is upside down, but rather about the structure that makes it a shirt.

You can learn more about the different augmentations and learned invariances here: Advanced Concepts in Self-Supervised Learning.

collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5
)

# We create a torchvision transformation for embedding the dataset after
# training
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])

dataset_train_simclr = lightly.data.LightlyDataset(
    input_dir=path_to_data
)

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_data,
    transform=test_transforms
)

dataloader_train_simclr = torch.utils.data.DataLoader(
    dataset_train_simclr,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)
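To see what the collate function produces, we can peek at a single batch. This is a quick check, assuming the collate function returns a tuple of the form ((view_0, view_1), labels, filenames) as in this version of lightly:

# fetch one batch: two independently augmented views of each image
(view_0, view_1), labels, filenames = next(iter(dataloader_train_simclr))
# with batch_size = 256 and input_size = 128 we expect
# torch.Size([256, 3, 128, 128]) for each view
print(view_0.shape, view_1.shape)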

Create the SimCLR model

Create a ResNet backbone and remove the classification head

resnet = torchvision.models.resnet18()
last_conv_channels = list(resnet.children())[-1].in_features
backbone = nn.Sequential(
    *list(resnet.children())[:-1],
    nn.Conv2d(last_conv_channels, num_ftrs, 1),
)

# create the SimCLR model using the newly created backbone
model = lightly.models.SimCLR(backbone, num_ftrs=num_ftrs)
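As an optional sanity check, we can verify that the backbone maps a batch of images to num_ftrs-dimensional features; the shapes below assume the configuration from above (input_size = 128, num_ftrs = 32):

# the ResNet-18 trunk ends in adaptive average pooling with output
# shape (B, 512, 1, 1); the added 1x1 convolution projects this
# down to (B, num_ftrs, 1, 1)
with torch.no_grad():
    dummy_input = torch.randn(2, 3, input_size, input_size)
    dummy_features = backbone(dummy_input)
print(dummy_features.shape)  # expected: torch.Size([2, 32, 1, 1])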

We now use the SelfSupervisedEmbedding class from the embedding module. First, we create a criterion and an optimizer, and then pass them to the SelfSupervisedEmbedding together with the model and the dataloader.

criterion = lightly.loss.NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
encoder = lightly.embedding.SelfSupervisedEmbedding(
    model,
    criterion,
    optimizer,
    dataloader_train_simclr
)

Use a GPU if available

gpus = 1 if torch.cuda.is_available() else 0

Train the Embedding

The encoder itself wraps a PyTorch Lightning module. We can pass any Lightning Trainer parameter (e.g. gpus=, max_epochs=) to the train_embedding method.

encoder.train_embedding(gpus=gpus,
                        progress_bar_refresh_rate=100,
                        max_epochs=max_epochs)

Out:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | model     | SimCLR     | 11.2 M
1 | criterion | NTXentLoss | 0
-----------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.793    Total estimated model params size (MB)

Epoch 0: 100%|##########| 22/22 [00:16<00:00,  1.30it/s, loss=5.86, v_num=75]
Epoch 1: 100%|##########| 22/22 [00:14<00:00,  1.52it/s, loss=5.66, v_num=75]
Epoch 2: 100%|##########| 22/22 [00:14<00:00,  1.52it/s, loss=5.58, v_num=75]
Epoch 3: 100%|##########| 22/22 [00:14<00:00,  1.50it/s, loss=5.51, v_num=75]
Epoch 4: 100%|##########| 22/22 [00:13<00:00,  1.61it/s, loss=5.48, v_num=75]
Epoch 5: 100%|##########| 22/22 [00:14<00:00,  1.53it/s, loss=5.43, v_num=75]
Epoch 6: 100%|##########| 22/22 [00:14<00:00,  1.53it/s, loss=5.44, v_num=75]
Epoch 7: 100%|##########| 22/22 [00:14<00:00,  1.54it/s, loss=5.39, v_num=75]
Epoch 8: 100%|##########| 22/22 [00:13<00:00,  1.58it/s, loss=5.35, v_num=75]
Epoch 9: 100%|##########| 22/22 [00:13<00:00,  1.59it/s, loss=5.35, v_num=75]
Epoch 10: 100%|##########| 22/22 [00:14<00:00,  1.50it/s, loss=5.33, v_num=75]
Epoch 11: 100%|##########| 22/22 [00:14<00:00,  1.55it/s, loss=5.31, v_num=75]
Epoch 12: 100%|##########| 22/22 [00:14<00:00,  1.51it/s, loss=5.29, v_num=75]
Epoch 13: 100%|##########| 22/22 [00:14<00:00,  1.52it/s, loss=5.29, v_num=75]
Epoch 14: 100%|##########| 22/22 [00:14<00:00,  1.51it/s, loss=5.27, v_num=75]
Epoch 15: 100%|##########| 22/22 [00:14<00:00,  1.50it/s, loss=5.25, v_num=75]
Epoch 16: 100%|##########| 22/22 [00:14<00:00,  1.53it/s, loss=5.25, v_num=75]
Epoch 17: 100%|##########| 22/22 [00:14<00:00,  1.53it/s, loss=5.23, v_num=75]
Epoch 18: 100%|##########| 22/22 [00:13<00:00,  1.59it/s, loss=5.21, v_num=75]
Epoch 19: 100%|##########| 22/22 [00:14<00:00,  1.55it/s, loss=5.22, v_num=75]

Now, let’s make sure we move the trained model to the GPU if we have one

device = 'cuda' if gpus == 1 else 'cpu'
encoder = encoder.to(device)

We can use the .embed method to create an embedding of the dataset. The method returns the embedding vectors, the labels, and a list of filenames; we don't need the labels here.

embeddings, _, fnames = encoder.embed(dataloader_test, device=device)
embeddings = normalize(embeddings)

Out:

Compute efficiency: 0.15: 100%|##########| 23/23 [00:06<00:00,  3.70it/s]
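We normalized the embeddings so that a nearest neighbor search with the Euclidean distance behaves like one based on cosine similarity: for unit vectors a and b we have ||a - b||^2 = 2 - 2 * a.dot(b), so both induce the same neighbor ordering. A quick numeric check (assuming embeddings is a numpy array of shape (N, d), as returned above):

# for L2-normalized vectors, squared Euclidean distance and cosine
# similarity are monotonically related; the two prints should match
a, b = embeddings[0], embeddings[1]
print(np.linalg.norm(a - b) ** 2)
print(2 - 2 * np.dot(a, b))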

Visualize Nearest Neighbors

Let’s look at the trained embedding and visualize the nearest neighbors for a few random samples.

We create some helper functions to simplify the work

def get_image_as_np_array(filename: str):
    """Returns an image as a numpy array
    """
    img = Image.open(filename)
    return np.asarray(img)

def plot_knn_examples(embeddings, n_neighbors=3, num_examples=6):
    """Plots multiple rows of random images with their nearest neighbors
    """
    # fit a nearest neighbors model on the embeddings (sklearn)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)

    # get num_examples random samples
    samples_idx = np.random.choice(len(indices), size=num_examples, replace=False)

    # loop through our randomly picked samples
    for idx in samples_idx:
        fig = plt.figure()
        # loop through their nearest neighbors
        for plot_x_offset, neighbor_idx in enumerate(indices[idx]):
            # add the subplot
            ax = fig.add_subplot(1, len(indices[idx]), plot_x_offset + 1)
            # get the corresponding filename for the current index
            fname = os.path.join(path_to_data, fnames[neighbor_idx])
            # plot the image
            plt.imshow(get_image_as_np_array(fname))
            # set the title to the distance of the neighbor
            ax.set_title(f'd={distances[idx][plot_x_offset]:.3f}')
            # disable the axis
            plt.axis('off')

Let’s plot the images. The leftmost image in each row is the query image, and the images next to it are its nearest neighbors. The title of each subplot shows the distance to the query.

plot_knn_examples(embeddings)
(Output: six figures; each row shows a query image, d=0.000, followed by its nearest neighbors with distances between d=0.073 and d=0.214.)

Color Invariance

Let’s train again without color augmentation. This will force our model to respect the colors in the images.

# Set color jitter and gray scale probability to 0
new_collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5,
    cj_prob=0.0,
    random_gray_scale=0.0
)

# let's update our collate method and reuse our dataloader
dataloader_train_simclr.collate_fn = new_collate_fn

# create a ResNet backbone and remove the classification head
resnet = torchvision.models.resnet18()
last_conv_channels = list(resnet.children())[-1].in_features
backbone = nn.Sequential(
    *list(resnet.children())[:-1],
    nn.Conv2d(last_conv_channels, num_ftrs, 1),
)
model = lightly.models.SimCLR(backbone, num_ftrs=num_ftrs)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
encoder = lightly.embedding.SelfSupervisedEmbedding(
    model,
    criterion,
    optimizer,
    dataloader_train_simclr
)

encoder.train_embedding(gpus=gpus,
                        progress_bar_refresh_rate=100,
                        max_epochs=max_epochs)
encoder = encoder.to(device)

embeddings, _, fnames = encoder.embed(dataloader_test, device=device)
embeddings = normalize(embeddings)

Out:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | model     | SimCLR     | 11.2 M
1 | criterion | NTXentLoss | 0
-----------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.793    Total estimated model params size (MB)

Epoch 0: 100%|##########| 22/22 [00:10<00:00,  2.11it/s, loss=5.29, v_num=76]
Epoch 1: 100%|##########| 22/22 [00:09<00:00,  2.22it/s, loss=4.89, v_num=76]
Epoch 2: 100%|##########| 22/22 [00:10<00:00,  2.19it/s, loss=4.8, v_num=76]
Epoch 3: 100%|##########| 22/22 [00:09<00:00,  2.37it/s, loss=4.74, v_num=76]
Epoch 4: 100%|##########| 22/22 [00:10<00:00,  2.15it/s, loss=4.73, v_num=76]
Epoch 5: 100%|##########| 22/22 [00:09<00:00,  2.30it/s, loss=4.7, v_num=76]
Epoch 6: 100%|##########| 22/22 [00:10<00:00,  2.16it/s, loss=4.7, v_num=76]
Epoch 7: 100%|##########| 22/22 [00:09<00:00,  2.24it/s, loss=4.67, v_num=76]
Epoch 8: 100%|##########| 22/22 [00:09<00:00,  2.27it/s, loss=4.66, v_num=76]
Epoch 9: 100%|##########| 22/22 [00:10<00:00,  2.20it/s, loss=4.65, v_num=76]
Epoch 10: 100%|##########| 22/22 [00:10<00:00,  2.18it/s, loss=4.64, v_num=76]
Epoch 11: 100%|##########| 22/22 [00:10<00:00,  2.17it/s, loss=4.63, v_num=76]
Epoch 12: 100%|##########| 22/22 [00:10<00:00,  2.19it/s, loss=4.63, v_num=76]
Epoch 13: 100%|##########| 22/22 [00:09<00:00,  2.23it/s, loss=4.61, v_num=76]
Epoch 14: 100%|##########| 22/22 [00:09<00:00,  2.36it/s, loss=4.61, v_num=76]
Epoch 15: 100%|##########| 22/22 [00:10<00:00,  2.15it/s, loss=4.61, v_num=76]
Epoch 16: 100%|##########| 22/22 [00:09<00:00,  2.21it/s, loss=4.6, v_num=76]
Epoch 17: 100%|##########| 22/22 [00:09<00:00,  2.22it/s, loss=4.59, v_num=76]
Epoch 18: 100%|##########| 22/22 [00:09<00:00,  2.28it/s, loss=4.58, v_num=76]
Epoch 19: 100%|##########| 22/22 [00:09<00:00,  2.34it/s, loss=4.57, v_num=76]

Compute efficiency: 0.14: 100%|##########| 23/23 [00:05<00:00,  3.84it/s]

Let's plot the nearest neighbor examples again, this time for the model trained without color augmentations.

plot_knn_examples(embeddings)
(Output: six figures; each row shows a query image, d=0.000, followed by its nearest neighbors with distances between d=0.057 and d=0.307.)

What’s next?

# You could use the pre-trained model and train a classifier on top.
pretrained_resnet_backbone = model.backbone

# you can also store the backbone and use it in other code
state_dict = {
    'resnet18_parameters': pretrained_resnet_backbone.state_dict()
}
torch.save(state_dict, 'model.pth')
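As a hypothetical sketch of the classifier idea (num_classes and the labeled dataset are assumptions, not part of this tutorial), you could freeze the backbone and train a linear head on top of its num_ftrs-dimensional output:

# hypothetical: depends on the number of classes in your labeled dataset
num_classes = 10
classifier = nn.Sequential(
    pretrained_resnet_backbone,
    nn.Flatten(),
    nn.Linear(num_ftrs, num_classes),
)
# freeze the backbone so that only the linear head is trained
for param in pretrained_resnet_backbone.parameters():
    param.requires_grad = False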

The following code could be in a new file (e.g. inference.py).

Make sure you place the model.pth file in the same folder as this code

# load the model in a new file for inference
resnet18_new = torchvision.models.resnet18()
last_conv_channels = list(resnet18_new.children())[-1].in_features
# note that we need to create exactly the same backbone in order to load the weights
backbone_new = nn.Sequential(
    *list(resnet18_new.children())[:-1],
    nn.Conv2d(last_conv_channels, num_ftrs, 1),
)

ckpt = torch.load('model.pth')
backbone_new.load_state_dict(ckpt['resnet18_parameters'])

Out:

<All keys matched successfully>