lightly.embedding
The lightly.embedding module provides trainable embedding strategies.
The embedding models use a pre-trained ResNet but should be fine-tuned on each dataset they are applied to.
lightly.embedding.embedding
Embedding Strategies
class lightly.embedding.embedding.SelfSupervisedEmbedding(model: torch.nn.modules.module.Module, criterion: torch.nn.modules.module.Module, optimizer: torch.optim.optimizer.Optimizer, dataloader: torch.utils.data.dataloader.DataLoader, scheduler=None)
Implementation of self-supervised embedding models.
Implements an embedding strategy based on self-supervised learning. A model backbone, self-supervised criterion, optimizer, and dataloader are passed to the constructor. The embedding itself is a PyTorch Lightning module and can be trained with a single call to train_embedding, which forwards PyTorch Lightning Trainer arguments as keyword arguments:
https://pytorch-lightning.readthedocs.io/en/stable/
The implementation is based on contrastive learning.
SimCLR: https://arxiv.org/abs/2002.05709
MoCo: https://arxiv.org/abs/1911.05722
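For intuition, the SimCLR-style criterion (NT-Xent) can be sketched in a few lines of plain PyTorch. This is an illustrative sketch only, not lightly's own loss implementation; the function name and the temperature default are assumptions.
>>> import torch
>>> import torch.nn.functional as F
>>>
>>> def nt_xent_loss(z0, z1, temperature=0.5):
>>>     # z0, z1: embeddings of two augmented views, shape (batch_size, dim)
>>>     batch_size = z0.shape[0]
>>>     z = F.normalize(torch.cat([z0, z1], dim=0), dim=1)
>>>     sim = z @ z.t() / temperature
>>>     sim.fill_diagonal_(float('-inf'))  # ignore self-similarity
>>>     # the positive for each view is the other view of the same image
>>>     targets = torch.cat([torch.arange(batch_size) + batch_size,
>>>                          torch.arange(batch_size)]).to(z.device)
>>>     return F.cross_entropy(sim, targets)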
- Attributes:
- model: A backbone convolutional network with a projection head.
- criterion: A contrastive loss function.
- optimizer: A PyTorch optimizer.
- dataloader: A PyTorch dataloader (torch.utils.data.DataLoader).
- scheduler: A PyTorch learning rate scheduler.
- Examples:
>>> # define a model, criterion, optimizer, and dataloader above
>>> import lightly.embedding as embedding
>>> encoder = SelfSupervisedEmbedding(
>>>     model,
>>>     criterion,
>>>     optimizer,
>>>     dataloader,
>>> )
>>> # train the self-supervised embedding with default settings
>>> encoder.train_embedding()
>>> # pass pytorch-lightning trainer arguments as kwargs
>>> encoder.train_embedding(max_epochs=10)
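The example assumes that model, criterion, optimizer, and dataloader already exist. The snippet below is a rough sketch of how they might be assembled; it is not taken from the library. The torchvision backbone, the hand-rolled projection head, and the dataset path are illustrative assumptions, and the NTXentLoss import path may differ between lightly versions. For actual self-supervised training the dataloader should yield two augmented views per image, which lightly's dataset and collate helpers provide.
>>> import torch
>>> import torch.nn as nn
>>> import torchvision
>>> from lightly.loss import NTXentLoss  # import path assumed, may vary by version
>>>
>>> # backbone: ResNet-18 with its classification layer removed
>>> resnet = torchvision.models.resnet18(pretrained=True)
>>> backbone = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())
>>>
>>> # projection head: a small MLP on top of the 512-dimensional backbone features
>>> model = nn.Sequential(backbone, nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 128))
>>>
>>> criterion = NTXentLoss()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
>>>
>>> # illustrative dataset; self-supervised training needs a two-view collate setup
>>> transform = torchvision.transforms.ToTensor()
>>> dataset = torchvision.datasets.ImageFolder('path/to/images', transform=transform)
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)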
embed(dataloader: torch.utils.data.dataloader.DataLoader, device: torch.device = None, to_numpy: bool = True)
Embeds images in a vector space.
- Args:
- dataloader: A PyTorch dataloader.
- device: Selected device (see PyTorch documentation).
- to_numpy: Whether to return the embeddings as a numpy array.
- Returns:
A tuple (embeddings, labels, fnames), where embeddings is a tensor or ndarray of shape n_images x num_ftrs, and labels and fnames are the corresponding labels and filenames.
- Examples:
>>> # embed images in a vector space
>>> embeddings, labels, fnames = encoder.embed(dataloader)
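As a rough illustration (assuming the encoder and dataloader from the examples above), the numpy embeddings can be used directly for downstream tasks such as a cosine nearest-neighbour lookup; the query index and the number of neighbours below are arbitrary choices.
>>> import numpy as np
>>> embeddings, labels, fnames = encoder.embed(dataloader, to_numpy=True)
>>> # cosine nearest neighbours of the first image
>>> normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
>>> similarities = normed @ normed[0]
>>> nearest = np.argsort(-similarities)[1:6]  # skip the query itself
>>> print([fnames[i] for i in nearest])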