lightly.loss
The lightly.loss package provides loss functions for self-supervised learning.
- class lightly.loss.barlow_twins_loss.BarlowTwinsLoss(lambda_param: float = 0.005, gather_distributed: bool = False)
Implementation of the Barlow Twins loss from the Barlow Twins paper [0].
This code specifically implements Algorithm 1 from [0].
[0] Zbontar, J. et al., 2021, Barlow Twins, https://arxiv.org/abs/2103.03230
Examples
>>> # initialize loss function
>>> loss_fn = BarlowTwinsLoss()
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through the Barlow Twins model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- forward(z_a: Tensor, z_b: Tensor) Tensor
Computes the Barlow Twins loss for the given projections.
- Parameters
z_a – Output projections of the first set of transformed images.
z_b – Output projections of the second set of transformed images.
- Returns
Computed Barlow Twins Loss.
- class lightly.loss.dcl_loss.DCLLoss(temperature: float = 0.1, weight_fn: Optional[Callable[[Tensor, Tensor], Tensor]] = None, gather_distributed: bool = False)
Implementation of the Decoupled Contrastive Learning Loss from Decoupled Contrastive Learning [0].
This code implements Equation 6 in [0], including the sum over all images i and views k. The loss is reduced to a mean loss over the mini-batch. The implementation was inspired by [1].
[0] Chun-Hsiao, Y. et al., 2021, Decoupled Contrastive Learning, https://arxiv.org/abs/2110.06848
[1] https://github.com/raminnakhli/Decoupled-Contrastive-Learning
- temperature
Similarities are scaled by inverse temperature.
- weight_fn
Weighting function w from the paper. Scales the loss between the positive views (views from the same image). No weighting is performed if weight_fn is None. The function must take the two input tensors passed to the forward call as input and return a weight tensor. The returned weight tensor must have the same length as the input tensors.
- gather_distributed
If True, negatives from all GPUs are gathered before the loss calculation.
Examples
>>> loss_fn = DCLLoss(temperature=0.07)
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # embed images using some model, for example SimCLR
>>> out0 = model(t0)
>>> out1 = model(t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
>>>
>>> # you can also add a custom weighting function
>>> weight_fn = lambda out0, out1: torch.sum((out0 - out1) ** 2, dim=1)
>>> loss_fn = DCLLoss(weight_fn=weight_fn)
- forward(out0: Tensor, out1: Tensor) Tensor
Forward pass of the DCL loss.
- Parameters
out0 – Output projections of the first set of transformed images. Shape: (batch_size, embedding_size)
out1 – Output projections of the second set of transformed images. Shape: (batch_size, embedding_size)
- Returns
Mean loss over the mini-batch.
- class lightly.loss.dcl_loss.DCLWLoss(temperature: float = 0.1, sigma: float = 0.5, gather_distributed: bool = False)
Implementation of the Weighted Decoupled Contrastive Learning Loss from Decoupled Contrastive Learning [0].
This code implements Equation 6 in [0] with a negative von Mises-Fisher weighting function. The loss returns the mean over all images i and views k in the mini-batch. The implementation was inspired by [1].
[0] Chun-Hsiao, Y. et al., 2021, Decoupled Contrastive Learning, https://arxiv.org/abs/2110.06848
[1] https://github.com/raminnakhli/Decoupled-Contrastive-Learning
- temperature
Similarities are scaled by inverse temperature.
- sigma
Similar to temperature but applies the inverse scaling in the weighting function.
- gather_distributed
If True, negatives from all GPUs are gathered before the loss calculation.
Examples
>>> loss_fn = DCLWLoss(temperature=0.07)
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # embed images using some model, for example SimCLR
>>> out0 = model(t0)
>>> out1 = model(t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- class lightly.loss.dino_loss.DINOLoss(output_dim: int = 65536, warmup_teacher_temp: float = 0.04, teacher_temp: float = 0.04, warmup_teacher_temp_epochs: int = 30, student_temp: float = 0.1, center_momentum: float = 0.9, center_mode: str = 'mean')
Implementation of the loss described in ‘Emerging Properties in Self-Supervised Vision Transformers’ [0].
This implementation follows the code published by the authors. [1] It supports global and local image crops. A linear warmup schedule for the teacher temperature is implemented to stabilize training at the beginning. Centering is applied to the teacher output to avoid model collapse.
[0]: DINO, 2021, https://arxiv.org/abs/2104.14294
[1]: https://github.com/facebookresearch/dino
- output_dim
Dimension of the model output.
- warmup_teacher_temp
Initial value of the teacher temperature. Should be decreased if the training loss does not decrease.
- teacher_temp
Final value of the teacher temperature after linear warmup. Values above 0.07 result in unstable behavior in most cases. Can be slightly increased to improve performance during fine-tuning.
- warmup_teacher_temp_epochs
Number of epochs for the teacher temperature warmup.
- student_temp
Temperature of the student.
- center_momentum
Momentum term for the center calculation.
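The linear warmup of the teacher temperature can be sketched as follows. This is illustrative only and not code from this module; the temperatures 0.04 and 0.07 and the epoch counts are example values.
>>> import numpy as np
>>> # teacher temperature per epoch: linear warmup, then constant
>>> schedule = np.concatenate([
...     np.linspace(0.04, 0.07, 30),  # warmup_teacher_temp to teacher_temp
...     np.full(70, 0.07),            # remaining epochs at teacher_temp
... ])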
Examples
>>> # initialize loss function
>>> loss_fn = DINOLoss(128)
>>>
>>> # generate a view of the images with a random transform
>>> view = transform(images)
>>>
>>> # embed the view with a student and teacher model
>>> teacher_out = teacher(view)
>>> student_out = student(view)
>>>
>>> # calculate loss
>>> loss = loss_fn([teacher_out], [student_out], epoch=0)
- forward(teacher_out: List[Tensor], student_out: List[Tensor], epoch: int, update_center: bool = True) Tensor
Cross-entropy between softmax outputs of the teacher and student networks.
- Parameters
teacher_out – List of tensors with shape (batch_size, output_dim) containing features from the teacher model. Each tensor must represent one view of the batch.
student_out – List of tensors with shape (batch_size, output_dim) containing features from the student model. Each tensor must represent one view of the batch.
epoch – The current training epoch.
update_center – If True, the center used for the teacher output is updated after the loss calculation.
- Returns
The average cross-entropy loss.
- update_center(teacher_out: Tensor) None
Moving average update of the center used for the teacher output.
- Parameters
teacher_out – Tensor with shape (num_views, batch_size, output_dim) containing features from the teacher model.
- class lightly.loss.hypersphere_loss.HypersphereLoss(t=1.0, lam=1.0, alpha=2.0)
Implementation of the loss described in ‘Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.’ [0]
[0] Tongzhou Wang et al., 2020, https://arxiv.org/abs/2005.10242
Note
In order for this loss to function as advertised, an L2-normalization onto the hypersphere is required. This loss function applies the normalization internally in the loss layer. However, it is recommended to apply the same normalization in your architecture as well, since the normalized embeddings are also intended to be used during inference. There may be merit in leaving the normalization out of the inference pathway, but this use has not been tested.
Moreover, it is recommended that the layers preceding this loss function are a linear layer without activation, a batch-normalization layer, or both. The directly upstream architecture has a large influence on the ability of this loss to achieve its stated aim of promoting uniformity on the hypersphere. If, by contrast, the last layer going into the embedding is a ReLU or similar nonlinearity, the embeddings will never get very close to uniformity on the hypersphere and will instead be confined to the subspace of positive activations. Similar architectural considerations are relevant to most contrastive loss functions, but we call them out here explicitly.
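As an illustrative sketch of this recommendation (not part of the module; the layer sizes and the features tensor are placeholders), a projection head ending in a linear layer whose output is normalized onto the hypersphere:
>>> features = torch.randn(10, 512)  # backbone output (placeholder)
>>> head = torch.nn.Linear(512, 128)  # final linear layer without activation
>>> z = torch.nn.functional.normalize(head(features), dim=1)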
Examples
>>> # initialize loss function
>>> loss_fn = HypersphereLoss()
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through the model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- forward(z_a: Tensor, z_b: Tensor) Tensor
Computes the Hypersphere loss, which combines alignment and uniformity loss terms.
- Parameters
z_a – Tensor of shape (batch_size, embedding_dim) for the first set of embeddings.
z_b – Tensor of shape (batch_size, embedding_dim) for the second set of embeddings.
- Returns
The computed loss.
- class lightly.loss.ibot_loss.IBOTPatchLoss(output_dim: int, teacher_temp: float = 0.04, student_temp: float = 0.1, center_mode: str = 'mean', center_momentum: float = 0.9)
Implementation of the iBOT patch loss [0] as used in DINOv2 [1].
Implementation is based on [2].
[0]: iBOT, 2021, https://arxiv.org/abs/2111.07832
[1]: DINOv2, 2023, https://arxiv.org/abs/2304.07193
[2]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/ibot_patch_loss.py
- output_dim
Dimension of the model output.
- teacher_temp
Temperature for the teacher output.
- student_temp
Temperature for the student output.
- center_mode
Mode for center calculation. Only ‘mean’ is supported.
- center_momentum
Momentum term for the center update.
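Examples
The following is an illustrative sketch, not from the original docstring; the shapes follow the forward documentation below and the tensors are random placeholders.
>>> # initialize loss function
>>> loss_fn = IBOTPatchLoss(output_dim=256)
>>>
>>> # boolean token mask for 2 images with a 7x7 patch grid,
>>> # masking 4 tokens per image (batch_size * sequence_length = 8)
>>> mask = torch.zeros(2, 7, 7, dtype=torch.bool)
>>> mask[:, 0, :4] = True
>>>
>>> # teacher and student outputs for the masked tokens
>>> teacher_out = torch.rand(8, 256)
>>> student_out = torch.rand(8, 256)
>>>
>>> # calculate loss
>>> loss = loss_fn(teacher_out, student_out, mask)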
- forward(teacher_out: Tensor, student_out: Tensor, mask: Tensor) Tensor
Forward pass through the iBOT patch loss.
- Parameters
teacher_out – Tensor with shape (batch_size * sequence_length, embed_dim) containing the teacher output of the masked tokens.
student_out – Tensor with shape (batch_size * sequence_length, embed_dim) containing the student output of the masked tokens.
mask – Boolean tensor with shape (batch_size, height, width) containing the token mask. Exactly batch_size * sequence_length entries must be set to True in the mask.
- Returns
The loss value.
- class lightly.loss.koleo_loss.KoLeoLoss(p: float = 2, eps: float = 1e-08)
KoLeo loss based on [0].
KoLeo loss is a regularizer that encourages a uniform span of the features in a batch by penalizing the distance between the features and their nearest neighbors.
Implementation is based on [1].
[0]: Spreading vectors for similarity search, 2019, https://arxiv.org/abs/1806.03198
[1]: https://github.com/facebookresearch/dinov2/blob/main/dinov2/loss/koleo_loss.py
- p
The norm degree for pairwise distance calculation.
- eps
Small value to avoid division by zero.
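Examples
The following is an illustrative sketch, not from the original docstring; model and images are placeholders, following the usage of the other losses in this module.
>>> # initialize loss function
>>> loss_fn = KoLeoLoss()
>>>
>>> # embed a batch of images
>>> x = model(images)
>>>
>>> # calculate regularization loss
>>> loss = loss_fn(x)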
- forward(x: Tensor) Tensor
Forward pass through KoLeo Loss.
- Parameters
x – Tensor with shape (batch_size, embedding_size).
- Returns
Loss value.
- class lightly.loss.memory_bank.MemoryBankModule(size: Union[int, Sequence[int]] = 65536, gather_distributed: bool = False, feature_dim_first: bool = True)
Memory bank implementation.
This is a parent class to all loss functions implemented by the lightly Python package. This way, any loss can be used with a memory bank if desired.
- size
Size of the memory bank as (num_features, dim) tuple. If num_features is 0 then the memory bank is disabled. Deprecated: If only a single integer is passed, it is interpreted as the number of features and the feature dimension is inferred from the first batch stored in the memory bank. Leaving out the feature dimension might lead to errors in distributed training.
- gather_distributed
If True, negatives from all GPUs are gathered before the memory bank is updated. This results in more frequent updates of the memory bank and keeps the memory bank contents independent of the number of GPUs. The drawback is that synchronization between processes is required and the diversity of the memory bank content is reduced.
- feature_dim_first
If True, the memory bank returns features with shape (dim, num_features). If False, the memory bank returns features with shape (num_features, dim).
Examples
>>> class MyLossFunction(MemoryBankModule):
>>>
>>>     def __init__(self, memory_bank_size: Tuple[int, int] = (2 ** 16, 128)):
>>>         super().__init__(memory_bank_size)
>>>
>>>     def forward(self, output: Tensor, labels: Optional[Tensor] = None):
>>>         output, negatives = super().forward(output)
>>>
>>>         if negatives is not None:
>>>             ...  # evaluate loss with negative samples
>>>         else:
>>>             ...  # evaluate loss without negative samples
- forward(output: Tensor, labels: Optional[Tensor] = None, update: bool = False) Tuple[Tensor, Optional[Tensor]]
Queries the memory bank for additional negative samples.
- Parameters
output – The output of the model.
labels – Should always be None, will be ignored.
update – If True, the memory bank will be updated with the current output.
- Returns
The output if the memory bank is of size 0, otherwise the output and the entries from the memory bank. Entries from the memory bank have shape (dim, num_features) if feature_dim_first is True and (num_features, dim) otherwise.
- class lightly.loss.mmcr_loss.MMCRLoss(lmda: float = 0.005)
Implementation of the loss function from MMCR [0] using Manifold Capacity. All hyperparameters are set to the default values from the paper for ImageNet.
[0]: Efficient Coding of Natural Images using Maximum Manifold Capacity Representations, 2023, https://arxiv.org/pdf/2303.03307.pdf
Examples
>>> # initialize loss function
>>> loss_fn = MMCRLoss()
>>> transform = MMCRTransform(k=2)
>>>
>>> # transform images, then feed through encoder and projector
>>> x = transform(x)
>>> online = online_network(x)
>>> momentum = momentum_network(x)
>>>
>>> # calculate loss
>>> loss = loss_fn(online, momentum)
- forward(online: Tensor, momentum: Tensor) Tensor
Computes the MMCR loss for the online and momentum network outputs.
- Parameters
online – Output of the online network for the current batch. Expected to be of shape (batch_size, k, embedding_size), where k represents the number of randomly augmented views for each sample.
momentum – Output of the momentum network for the current batch. Expected to be of shape (batch_size, k, embedding_size), where k represents the number of randomly augmented views for each sample.
- Returns
The computed loss value.
- class lightly.loss.msn_loss.MSNLoss(temperature: float = 0.1, sinkhorn_iterations: int = 3, regularization_weight: float = 1.0, me_max_weight: Optional[float] = None, gather_distributed: bool = False)
Implementation of the loss function from MSN [0].
Code inspired by [1].
[0]: Masked Siamese Networks, 2022, https://arxiv.org/abs/2204.07141
[1]: https://github.com/facebookresearch/msn
- temperature
Similarities between anchors and targets are scaled by the inverse of the temperature. Must be in (0, inf).
- sinkhorn_iterations
Number of sinkhorn normalization iterations on the targets.
- regularization_weight
Weight factor lambda by which the regularization loss is scaled. Set to 0 to disable regularization.
- me_max_weight
Deprecated, use regularization_weight instead. Takes precedence over regularization_weight if not None. Weight factor lambda by which the mean entropy maximization regularization loss is scaled. Set to 0 to disable mean entropy maximization regularization.
- gather_distributed
If True, then target probabilities are gathered from all GPUs.
Examples
>>> # initialize loss function
>>> loss_fn = MSNLoss()
>>>
>>> # generate anchors and targets of images
>>> anchors = transforms(images)
>>> targets = transforms(images)
>>>
>>> # feed through MSN model
>>> anchors_out = model(anchors)
>>> targets_out = model.target(targets)
>>>
>>> # calculate loss
>>> loss = loss_fn(anchors_out, targets_out, prototypes=model.prototypes)
- forward(anchors: Tensor, targets: Tensor, prototypes: Tensor, target_sharpen_temperature: float = 0.25) Tensor
Computes the MSN loss for a set of anchors, targets, and prototypes.
- Parameters
anchors – Tensor with shape (batch_size * anchor_views, dim).
targets – Tensor with shape (batch_size, dim).
prototypes – Tensor with shape (num_prototypes, dim).
target_sharpen_temperature – Temperature used to sharpen the target probabilities.
- Returns
Mean loss over all anchors.
- regularization_loss(mean_anchor_probs: Tensor) Tensor
Calculates mean entropy regularization loss.
- Parameters
mean_anchor_probs – The mean anchor probabilities.
- Returns
The calculated regularization loss.
- class lightly.loss.negative_cosine_similarity.NegativeCosineSimilarity(dim: int = 1, eps: float = 1e-08)
Implementation of the Negative Cosine Similarity used in the SimSiam paper [0].
[0] SimSiam, 2020, https://arxiv.org/abs/2011.10566
Examples
>>> # initialize loss function
>>> loss_fn = NegativeCosineSimilarity()
>>>
>>> # generate two representation tensors
>>> # with batch size 10 and dimension 128
>>> x0 = torch.randn(10, 128)
>>> x1 = torch.randn(10, 128)
>>>
>>> # calculate loss
>>> loss = loss_fn(x0, x1)
- forward(x0: Tensor, x1: Tensor) Tensor
Computes the negative cosine similarity between two tensors.
- Parameters
x0 – First input tensor.
x1 – Second input tensor.
- Returns
The mean negative cosine similarity.
- class lightly.loss.ntx_ent_loss.NTXentLoss(temperature: float = 0.5, memory_bank_size: Union[int, Sequence[int]] = 0, gather_distributed: bool = False)
Implementation of the Contrastive Cross Entropy Loss.
This implementation follows the SimCLR paper [0]. If you enable the memory bank by setting memory_bank_size > 0, the loss behaves like the one described in the MoCo paper [1].
[0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
[1] MoCo, 2020, https://arxiv.org/abs/1911.05722
- temperature
Scale logits by the inverse of the temperature.
- memory_bank_size
Size of the memory bank as (num_features, dim) tuple. num_features are the number of negative samples stored in the memory bank. If num_features is 0, the memory bank is disabled. Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536. Deprecated: If only a single integer is passed, it is interpreted as the number of features and the feature dimension is inferred from the first batch stored in the memory bank. Leaving out the feature dimension might lead to errors in distributed training.
- gather_distributed
If True, negatives from all GPUs are gathered before the loss calculation. If a memory bank is used and gather_distributed is True, then tensors from all GPUs are gathered before the memory bank is updated.
- Raises
ValueError – If abs(temperature) < 1e-8 to prevent divide by zero.
Examples
>>> # initialize loss function without memory bank
>>> loss_fn = NTXentLoss(memory_bank_size=0)
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through SimCLR or MoCo model
>>> batch = torch.cat((t0, t1), dim=0)
>>> output = model(batch)
>>>
>>> # calculate loss
>>> out0, out1 = output.chunk(2)
>>> loss = loss_fn(out0, out1)
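With a memory bank enabled, the negatives are drawn from the bank instead of the current batch. The following is a minimal sketch, not from the original docstring; out0 and out1 are random placeholders standing in for the projections described in the forward documentation below.
>>> out0 = torch.randn(32, 128)  # query projections (placeholder)
>>> out1 = torch.randn(32, 128)  # key projections (placeholder)
>>>
>>> # MoCo-style loss with a memory bank of 4096 negatives of dimension 128
>>> loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
>>> loss = loss_fn(out0, out1)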
- forward(out0: Tensor, out1: Tensor)
Forward pass through Contrastive Cross-Entropy Loss.
If used with a memory bank, the samples from the memory bank are used as negative examples. Otherwise, within-batch samples are used as negative samples.
- Parameters
out0 – Output projections of the first set of transformed images. Shape: (batch_size, embedding_size)
out1 – Output projections of the second set of transformed images. Shape: (batch_size, embedding_size)
- Returns
Contrastive Cross Entropy Loss value.
- class lightly.loss.pmsn_loss.PMSNLoss(temperature: float = 0.1, sinkhorn_iterations: int = 3, regularization_weight: float = 1, power_law_exponent: float = 0.25, gather_distributed: bool = False)
Implementation of the loss function from PMSN [0] using a power law target distribution.
[0]: Prior Matching for Siamese Networks, 2022, https://arxiv.org/abs/2210.07277
- temperature
Similarities between anchors and targets are scaled by the inverse of the temperature. Must be in (0, inf).
- sinkhorn_iterations
Number of sinkhorn normalization iterations on the targets.
- regularization_weight
Weight factor lambda by which the regularization loss is scaled. Set to 0 to disable regularization.
- power_law_exponent
Exponent for the power law distribution. Entry k of the distribution is proportional to (1 / k) ^ power_law_exponent, with k ranging from 1 to dim; see the sketch after this attribute list.
- gather_distributed
If True, then target probabilities are gathered from all GPUs.
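The power law target distribution can be sketched as follows. This is illustrative only and not code from this module; dim and the exponent are example values.
>>> dim = 128
>>> power_law_exponent = 0.25
>>> k = torch.arange(1.0, dim + 1)
>>> target = k ** (-power_law_exponent)
>>> target = target / target.sum()  # normalize so the entries sum to one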
Examples
>>> # initialize loss function
>>> loss_fn = PMSNLoss()
>>>
>>> # generate anchors and targets of images
>>> anchors = transforms(images)
>>> targets = transforms(images)
>>>
>>> # feed through PMSN model
>>> anchors_out = model(anchors)
>>> targets_out = model.target(targets)
>>>
>>> # calculate loss
>>> loss = loss_fn(anchors_out, targets_out, prototypes=model.prototypes)
- regularization_loss(mean_anchor_probs: Tensor) Tensor
Calculates the regularization loss with a power law target distribution.
- Parameters
mean_anchor_probs – The mean anchor probabilities.
- Returns
The calculated regularization loss.
- class lightly.loss.pmsn_loss.PMSNCustomLoss(target_distribution: Callable[[Tensor], Tensor], temperature: float = 0.1, sinkhorn_iterations: int = 3, regularization_weight: float = 1, gather_distributed: bool = False)
Implementation of the loss function from PMSN [0] with a custom target distribution.
[0]: Prior Matching for Siamese Networks, 2022, https://arxiv.org/abs/2210.07277
- target_distribution
A function that takes the mean anchor probabilities tensor with shape (dim,) as input and returns a target probability distribution tensor with the same shape. The returned distribution should sum up to one. The final regularization loss is calculated as KL(mean_anchor_probs, target_dist) where KL is the Kullback-Leibler divergence.
- temperature
Similarities between anchors and targets are scaled by the inverse of the temperature. Must be in (0, inf).
- sinkhorn_iterations
Number of sinkhorn normalization iterations on the targets.
- regularization_weight
Weight factor lambda by which the regularization loss is scaled. Set to 0 to disable regularization.
- gather_distributed
If True, then target probabilities are gathered from all GPUs.
Examples
>>> # define custom target distribution
>>> def my_uniform_distribution(mean_anchor_probabilities: Tensor) -> Tensor:
>>>     dim = mean_anchor_probabilities.shape[0]
>>>     return mean_anchor_probabilities.new_ones(dim) / dim
>>>
>>> # initialize loss function
>>> loss_fn = PMSNCustomLoss(target_distribution=my_uniform_distribution)
>>>
>>> # generate anchors and targets of images
>>> anchors = transforms(images)
>>> targets = transforms(images)
>>>
>>> # feed through PMSN model
>>> anchors_out = model(anchors)
>>> targets_out = model.target(targets)
>>>
>>> # calculate loss
>>> loss = loss_fn(anchors_out, targets_out, prototypes=model.prototypes)
- regularization_loss(mean_anchor_probs: Tensor) Tensor
Calculates regularization loss with a custom target distribution.
- Parameters
mean_anchor_probs – The mean anchor probabilities.
- Returns
The calculated regularization loss.
- class lightly.loss.regularizer.co2.CO2Regularizer(alpha: float = 1, t_consistency: float = 0.05, memory_bank_size: Union[int, Sequence[int]] = 0)
Implementation of the CO2 regularizer [0] for self-supervised learning.
[0] CO2, 2021, https://arxiv.org/abs/2010.02217
- alpha
Weight of the regularization term.
- t_consistency
Temperature used during softmax calculations.
- memory_bank_size
Size of the memory bank as (num_features, dim) tuple. num_features is the number of negatives stored in the bank. If set to 0, the memory bank is disabled. Deprecated: If only a single integer is passed, it is interpreted as the number of features and the feature dimension is inferred from the first batch stored in the memory bank. Leaving out the feature dimension might lead to errors in distributed training.
Examples
>>> # initialize loss function for MoCo
>>> loss_fn = NTXentLoss(memory_bank_size=(4096, 128))
>>>
>>> # initialize CO2 regularizer
>>> co2 = CO2Regularizer(alpha=1.0, memory_bank_size=(4096, 128))
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through the MoCo model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss and apply regularizer
>>> loss = loss_fn(out0, out1) + co2(out0, out1)
- forward(out0: Tensor, out1: Tensor)
Computes the CO2 regularization term for two model outputs.
- Parameters
out0 – Output projections of the first set of transformed images.
out1 – Output projections of the second set of transformed images.
- Returns
The regularization term multiplied by the weight factor alpha.
- class lightly.loss.swav_loss.SwaVLoss(temperature: float = 0.1, sinkhorn_iterations: int = 3, sinkhorn_epsilon: float = 0.05, sinkhorn_gather_distributed: bool = False)
Implementation of the SwaV loss.
- temperature
Temperature parameter used for cross entropy calculations.
- sinkhorn_iterations
Number of iterations of the sinkhorn algorithm.
- sinkhorn_epsilon
Temperature parameter used in the sinkhorn algorithm.
- sinkhorn_gather_distributed
If True, features from all GPUs are gathered to calculate the soft codes in the sinkhorn algorithm.
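Examples
The following is an illustrative sketch, not from the original docstring; model is assumed to map each crop to the similarities between its features and the SwaV prototypes, and transform_global, transform_local, and images are placeholders.
>>> # initialize loss function
>>> loss_fn = SwaVLoss()
>>>
>>> # generate two high resolution and several low resolution crops
>>> high_res = [transform_global(images) for _ in range(2)]
>>> low_res = [transform_local(images) for _ in range(6)]
>>>
>>> # compute similarities between features and prototypes
>>> high_res_out = [model(crop) for crop in high_res]
>>> low_res_out = [model(crop) for crop in low_res]
>>>
>>> # calculate loss
>>> loss = loss_fn(high_res_out, low_res_out)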
- forward(high_resolution_outputs: List[Tensor], low_resolution_outputs: List[Tensor], queue_outputs: Optional[List[Tensor]] = None)
Computes the SwaV loss for a set of high and low resolution outputs.
[0]: SwaV, 2020, https://arxiv.org/abs/2006.09882
- Parameters
high_resolution_outputs – List of similarities of features and SwaV prototypes for the high resolution crops.
low_resolution_outputs – List of similarities of features and SwaV prototypes for the low resolution crops.
queue_outputs – List of similarities of features and SwaV prototypes for the queue of high resolution crops from previous batches.
- Returns
Swapping assignments between views loss (SwaV) as described in [0].
- subloss(z: Tensor, q: Tensor)
Calculates the cross entropy for the SwaV prediction problem.
- Parameters
z – Similarity of the features and the SwaV prototypes.
q – Codes obtained from Sinkhorn iterations.
- Returns
Cross entropy between predictions z and codes q.
- class lightly.loss.sym_neg_cos_sim_loss.SymNegCosineSimilarityLoss
Implementation of the Symmetrized Loss used in the SimSiam paper [0].
[0] SimSiam, 2020, https://arxiv.org/abs/2011.10566
Examples
>>> # initialize loss function
>>> loss_fn = SymNegCosineSimilarityLoss()
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through SimSiam model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- forward(out0: Tensor, out1: Tensor)
Forward pass through Symmetric Loss.
- Parameters
out0 – Output for the first set of transformed images, expected as a tuple (z0, p0), where z0 is the output of the backbone and projection MLP and p0 is the output of the prediction head.
out1 – Output for the second set of transformed images, expected as a tuple (z1, p1), where z1 is the output of the backbone and projection MLP and p1 is the output of the prediction head.
- Returns
Negative Cosine Similarity loss value.
- class lightly.loss.tico_loss.TiCoLoss(beta: float = 0.9, rho: float = 8.0, gather_distributed: bool = False)
Implementation of the TiCo loss from the TiCo paper [0].
This implementation takes inspiration from the code published by sayannag using Lightly. [1]
[0] Jiachen Zhu et al., 2022, TiCo, https://arxiv.org/abs/2206.10698
- beta
Coefficient for the EMA update of the covariance matrix. Defaults to 0.9 [0].
- rho
Weight for the covariance term of the loss. Defaults to 8.0 [0].
- gather_distributed
If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation.
Examples
>>> # initialize loss function
>>> loss_fn = TiCoLoss()
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- forward(z_a: Tensor, z_b: Tensor, update_covariance_matrix: bool = True) Tensor
Computes the TiCo loss.
It maximizes the agreement between embeddings of different distorted versions of the same image while avoiding collapse using the covariance matrix.
- Parameters
z_a – Tensor of shape [batch_size, num_features=256]. Output of the learned backbone.
z_b – Tensor of shape [batch_size, num_features=256]. Output of the momentum updated backbone.
update_covariance_matrix – If True, the covariance matrix is updated at each iteration.
- Returns
The computed loss.
- Raises
AssertionError – If z_a or z_b have a batch size <= 1.
AssertionError – If z_a and z_b do not have the same shape.
- class lightly.loss.vicreg_loss.VICRegLoss(lambda_param: float = 25.0, mu_param: float = 25.0, nu_param: float = 1.0, gather_distributed: bool = False, eps=0.0001)
Implementation of the VICReg loss [0].
This implementation is based on the code published by the authors [1].
[0] VICReg, 2022, https://arxiv.org/abs/2105.04906
[1] https://github.com/facebookresearch/vicreg
- lambda_param
Scaling coefficient for the invariance term of the loss.
- mu_param
Scaling coefficient for the variance term of the loss.
- nu_param
Scaling coefficient for the covariance term of the loss.
- gather_distributed
If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation.
- eps
Epsilon for numerical stability.
Examples
>>> # initialize loss function
>>> loss_fn = VICRegLoss()
>>>
>>> # generate two random transforms of images
>>> t0 = transforms(images)
>>> t1 = transforms(images)
>>>
>>> # feed through model
>>> out0, out1 = model(t0, t1)
>>>
>>> # calculate loss
>>> loss = loss_fn(out0, out1)
- forward(z_a: Tensor, z_b: Tensor) Tensor
Returns VICReg loss.
- Parameters
z_a – Tensor with shape (batch_size, …, dim).
z_b – Tensor with shape (batch_size, …, dim).
- Returns
The computed VICReg loss.
- Raises
AssertionError – If z_a or z_b have a batch size <= 1.
AssertionError – If z_a and z_b do not have the same shape.
- class lightly.loss.vicregl_loss.VICRegLLoss(lambda_param: float = 25.0, mu_param: float = 25.0, nu_param: float = 1.0, alpha: float = 0.75, gather_distributed: bool = False, eps: float = 0.0001, num_matches: Tuple[int, int] = (20, 4))
Implementation of the VICRegL loss from VICRegL paper [0].
This implementation follows the code published by the authors [1].
[0]: VICRegL, 2022, https://arxiv.org/abs/2210.01571
[1]: https://github.com/facebookresearch/VICRegL
- lambda_param
Coefficient for the invariance term of the loss.
- mu_param
Coefficient for the variance term of the loss.
- nu_param
Coefficient for the covariance term of the loss.
- alpha
Coefficient to weight global with local loss. The final loss is computed as (self.alpha * global_loss + (1-self.alpha) * local_loss).
- gather_distributed
If True, the cross-correlation matrices from all GPUs are gathered and summed before the loss calculation.
- eps
Epsilon for numerical stability.
- num_matches
Number of local features to match using nearest neighbors.
Examples
>>> # initialize loss function
>>> criterion = VICRegLLoss()
>>> transform = VICRegLTransform(n_global_views=2, n_local_views=4)
>>>
>>> # generate the transformed views and their grids
>>> views_and_grids = transform(images)
>>> views = views_and_grids[:6]  # 2 global views + 4 local views
>>> grids = views_and_grids[6:]
>>>
>>> # feed the views through the model
>>> features = [model(view) for view in views]
>>>
>>> # calculate loss
>>> loss = criterion(
...     global_view_features=features[:2],
...     global_view_grids=grids[:2],
...     local_view_features=features[2:],
...     local_view_grids=grids[2:],
... )
- forward(global_view_features: Sequence[Tuple[Tensor, Tensor]], global_view_grids: Sequence[Tensor], local_view_features: Optional[Sequence[Tuple[Tensor, Tensor]]] = None, local_view_grids: Optional[Sequence[Tensor]] = None) Tensor
Computes the global and local VICRegL loss from the input features.
- Parameters
global_view_features – Sequence of (global_features, local_features) tuples from the global crop views. global_features must have size (batch_size, global_feature_dim) and local_features must have size (batch_size, grid_height, grid_width, local_feature_dim).
global_view_grids – Sequence of grid tensors from the global crop views. Every tensor must have shape (batch_size, grid_height, grid_width, 2).
local_view_features – Sequence of (global_features, local_features) tuples from the local crop views. global_features must have size (batch_size, global_feature_dim) and local_features must have size (batch_size, grid_height, grid_width, local_feature_dim). Note that grid_height and grid_width can differ between global_view_features and local_view_features.
local_view_grids – Sequence of grid tensors from the local crop views. Every tensor must have shape (batch_size, grid_height, grid_width, 2). Note that grid_height and grid_width can differ between global_view_features and local_view_features.
- Returns
Weighted sum of the global and local loss, calculated as self.alpha * global_loss + (1 - self.alpha) * local_loss.
- Raises
ValueError – If the lengths of global_view_features and global_view_grids are not the same.
ValueError – If the lengths of local_view_features and local_view_grids are not the same.
ValueError – If only one of local_view_features or local_view_grids is set.