lightly.embedding

The lightly.embedding module provides trainable embedding strategies.

The embedding models use a pre-trained ResNet backbone, which should be fine-tuned on each new dataset.

.embedding

Embedding Strategies

class lightly.embedding.embedding.SelfSupervisedEmbedding(model: torch.nn.modules.module.Module, criterion: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, dataloader: torch.utils.data.dataloader.DataLoader, scheduler=None)

Implementation of self-supervised embedding models.

Implements an embedding strategy based on self-supervised learning. A model backbone, a self-supervised criterion, an optimizer, and a dataloader are passed to the constructor. The embedding itself is a pytorch-lightning module, so training is handled by a pytorch-lightning trainer:

https://pytorch-lightning.readthedocs.io/en/stable/

The implementation is based on contrastive learning; see the following papers:

SimCLR: https://arxiv.org/abs/2002.05709

MoCo: https://arxiv.org/abs/1911.05722
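To make the idea of a contrastive criterion concrete, the following is a minimal, dependency-free sketch of the NT-Xent loss described in the SimCLR paper. It is written in plain Python for illustration only and is not the criterion shipped with this module. The names nt_xent_loss, views_a, views_b, and temperature are assumptions chosen for this example.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two non-zero vectors given as sequences of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent_loss(views_a, views_b, temperature=0.5):
    """Illustrative NT-Xent loss (hypothetical helper, not part of lightly).

    views_a[i] and views_b[i] are embeddings of two augmented views of image i.
    Each embedding is pulled towards its partner view and pushed away from
    every other embedding in the batch.
    """
    embeddings = list(views_a) + list(views_b)
    n = len(views_a)
    total = 0.0
    for i, anchor in enumerate(embeddings):
        # The positive for index i is the other view of the same image.
        positive = (i + n) % (2 * n)
        # Temperature-scaled similarities to all other embeddings in the batch.
        logits = {
            j: cosine_similarity(anchor, other) / temperature
            for j, other in enumerate(embeddings)
            if j != i
        }
        log_denominator = math.log(sum(math.exp(v) for v in logits.values()))
        # Cross-entropy with the positive pair as the target class.
        total += log_denominator - logits[positive]
    return total / (2 * n)
```

The loss is small when the two views of each image have similar embeddings and grows when views of different images are more similar to each other than to their own partner.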

Attributes:
model:

A backbone convolutional network with a projection head.

criterion:

A contrastive loss function.

optimizer:

A PyTorch optimizer.

dataloader:

A PyTorch dataloader (torch.utils.data.DataLoader).

scheduler:

A PyTorch learning rate scheduler (optional, defaults to None).

Examples:
>>> # define a model, criterion, optimizer, and dataloader above
>>> from lightly.embedding.embedding import SelfSupervisedEmbedding
>>> encoder = SelfSupervisedEmbedding(
...     model,
...     criterion,
...     optimizer,
...     dataloader,
... )
>>> # train the self-supervised embedding with default settings
>>> encoder.train_embedding()
>>> # pass pytorch-lightning trainer arguments as kwargs
>>> encoder.train_embedding(max_epochs=10)
embed(dataloader: torch.utils.data.dataloader.DataLoader, device: torch.device = None, to_numpy: bool = True)

Embeds images in a vector space.

Args:
dataloader:

A PyTorch dataloader (torch.utils.data.DataLoader).

device:

Selected device (see PyTorch documentation)

to_numpy:

Whether to return the embeddings as a numpy array.

Returns:

A tuple (embeddings, labels, fnames), where embeddings is a tensor (or an ndarray if to_numpy is True) of shape n_images x num_ftrs, and labels and fnames are the corresponding labels and filenames.

Examples:
>>> # embed images in vector space
>>> embeddings, labels, fnames = encoder.embed(dataloader)
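As one possible use of the returned tuple, the sketch below ranks images by the cosine similarity of their embeddings. It uses small hand-written lists as a stand-in for the output of embed, and the helper most_similar as well as the filenames are hypothetical, introduced only for this illustration.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(embeddings, fnames, query_index, k=2):
    """Return the filenames of the k images closest to the query image.

    Hypothetical helper: embeddings is a sequence of n_images rows with
    num_ftrs values each, fnames holds the matching filenames.
    """
    query = embeddings[query_index]
    scored = [
        (cosine_similarity(query, row), fname)
        for i, (row, fname) in enumerate(zip(embeddings, fnames))
        if i != query_index  # skip the query image itself
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [fname for _, fname in scored[:k]]


# Toy stand-in for `embeddings, labels, fnames = encoder.embed(dataloader)`.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
fnames = ["cat_0.jpg", "cat_1.jpg", "dog_0.jpg", "dog_1.jpg"]

neighbors = most_similar(embeddings, fnames, query_index=0, k=2)
```

For real datasets a vectorized implementation (for example with numpy) would be preferable; the loop form is used here only to keep the example free of dependencies.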