lightly.models.utils

Utils for working with SSL models

lightly.models.utils.activate_requires_grad(model: Module)

Activates the requires_grad flag for all parameters of a model.

Use this method to activate gradients for a model (e.g. after deactivating them using deactivate_requires_grad(…)).

Examples

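>>> from torchvision.models import resnet18
>>> from lightly.models.utils import activate_requires_grad
>>> backbone = resnet18()
>>> activate_requires_grad(backbone)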

lightly.models.utils.add_stochastic_depth_to_blocks(vit: Module, prob: float = 0.0, mode='row') → None

Adds stochastic depth dropout to all transformer blocks in a Vision Transformer model.

Parameters
  • vit – Vision Transformer Model to which stochastic depth dropout will be added.

  • prob – Probability of dropping a layer.

  • mode – Mode for stochastic depth. Default is “row”.

Raises

RuntimeError – If the torchvision version is lower than 0.12.

lightly.models.utils.batch_shuffle(batch: Tensor, distributed: bool = False) → Tuple[Tensor, Tensor]

Randomly shuffles all tensors in the batch.

Parameters
  • batch – The batch to shuffle.

  • distributed – If True, the batch is shuffled across multiple GPUs.

Returns

A (batch, shuffle) tuple where batch is the shuffled version of the input batch and shuffle is an index to restore the original order.

Examples

>>> # forward pass through the momentum model with batch shuffling
>>> x1_shuffled, shuffle = batch_shuffle(x1)
>>> f1 = moco_momentum(x1_shuffled)
>>> out1 = projection_head_momentum(f1)
>>> out1 = batch_unshuffle(out1, shuffle)
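
A minimal single-device sketch of the shuffle/unshuffle pair, assuming the shuffle is a random permutation along the batch dimension (dim 0) and the unshuffle index is its inverse. The helpers shuffle_batch and unshuffle_batch are hypothetical illustrations, not lightly's implementation, and the distributed variant is not covered.

```python
import torch


def shuffle_batch(batch: torch.Tensor):
    """Hypothetical sketch: randomly permute the samples along dim 0."""
    shuffle = torch.randperm(batch.shape[0], device=batch.device)
    return batch[shuffle], shuffle


def unshuffle_batch(batch: torch.Tensor, shuffle: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: undo the permutation with its inverse index."""
    unshuffle = torch.argsort(shuffle)
    return batch[unshuffle]


x = torch.arange(8).reshape(4, 2)
x_shuffled, shuffle = shuffle_batch(x)
# Per-sample computations can run on the shuffled batch ...
y_shuffled = x_shuffled * 10
# ... and their outputs are restored to the original sample order.
y = unshuffle_batch(y_shuffled, shuffle)
```

On a single device the permutation leaves BatchNorm statistics unchanged, because they are computed over the whole batch; shuffling mainly matters in the distributed setting, as in MoCo's shuffling BN.
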

lightly.models.utils.batch_shuffle_distributed(batch: Tensor) → Tuple[Tensor, Tensor]

Shuffles a batch across multiple devices.

Adapted from https://github.com/facebookresearch/moco.

Parameters

batch – The tensor to shuffle.

Returns

A (batch, shuffle) tuple where batch is the shuffled version of the input batch and shuffle is an index to restore the original order.

lightly.models.utils.batch_unshuffle(batch: Tensor, shuffle: Tensor, distributed: bool = False) → Tensor

Unshuffles a batch.

Parameters
  • batch – The batch to unshuffle.

  • shuffle – Index to unshuffle the batch.

  • distributed – If True, the batch is unshuffled across multiple GPUs.

Returns

The unshuffled batch.

Examples

>>> # forward pass through the momentum model with batch shuffling
>>> x1_shuffled, shuffle = batch_shuffle(x1)
>>> f1 = moco_momentum(x1_shuffled)
>>> out1 = projection_head_momentum(f1)
>>> out1 = batch_unshuffle(out1, shuffle)

lightly.models.utils.batch_unshuffle_distributed(batch: Tensor, shuffle: Tensor) → Tensor

Undoes a batch shuffle across multiple devices.

Adapted from https://github.com/facebookresearch/moco.

Parameters
  • batch – The tensor to unshuffle.

  • shuffle – Index to restore the original tensor.

Returns

The unshuffled tensor.

lightly.models.utils.concat_all_gather(x: Tensor) → Tensor

Gathers x from all GPUs and returns the concatenation of the gathered tensors.

Adapted from https://github.com/facebookresearch/moco.

lightly.models.utils.deactivate_requires_grad(model: Module)

Deactivates the requires_grad flag for all parameters of a model.

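This prevents the parameters of the model from receiving gradients, similar to permanently executing the model within a torch.no_grad() context. Unlike torch.no_grad(), gradients can still flow through the model to inputs that require them. Use this method to disable gradient computation for the parameters and therefore training of a model.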

Examples

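>>> from torchvision.models import resnet18
>>> from lightly.models.utils import deactivate_requires_grad
>>> backbone = resnet18()
>>> deactivate_requires_grad(backbone)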
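
The effect of toggling requires_grad can be checked with a small sketch. The helper set_requires_grad is hypothetical; it assumes that activate_requires_grad and deactivate_requires_grad simply set requires_grad on every parameter returned by model.parameters(), as described above.

```python
import torch
from torch import nn


def set_requires_grad(model: nn.Module, requires_grad: bool) -> None:
    """Hypothetical helper: set requires_grad on every parameter of the model."""
    for param in model.parameters():
        param.requires_grad = requires_grad


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Frozen (like deactivate_requires_grad): parameters receive no gradients ...
set_requires_grad(model, False)
inputs = torch.randn(3, 4, requires_grad=True)
model(inputs).sum().backward()
# ... but, unlike inside torch.no_grad(), gradients still reach the inputs.
frozen_param_grads = [p.grad for p in model.parameters()]  # all None
input_grad = inputs.grad  # populated

# Trainable again (like activate_requires_grad).
set_requires_grad(model, True)
model(torch.randn(3, 4)).sum().backward()
```
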

lightly.models.utils.expand_index_like(index: Tensor, tokens: Tensor) → Tensor

Expands the index along the last dimension of the input tokens.

Parameters
  • index – Index tensor with shape (batch_size, idx_length) where each entry is an index in [0, sequence_length).

  • tokens – Tokens tensor with shape (batch_size, sequence_length, dim).

Returns

Index tensor with shape (batch_size, idx_length, dim) where the original indices are repeated dim times along the last dimension.

lightly.models.utils.get_1d_sine_cosine_positional_embedding_from_positions(embed_dim: int, pos: NDArray[np.float32]) → NDArray[np.float32]

Generates a 1D sine-cosine positional embedding from positions.

Code follows: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

Parameters
  • embed_dim – Embedding dimension.

  • pos – Positions to be encoded with shape (N, M).

Returns

Positional embedding with shape (N * M, embed_dim).
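
A minimal NumPy sketch of a 1D sine-cosine embedding, modeled on the MAE reference code linked above. The function name sincos_1d is hypothetical, and the frequency base of 10000 is taken from that reference code; treat this as an illustration rather than lightly's implementation.

```python
import numpy as np


def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Hypothetical sketch of a 1D sine-cosine embedding (MAE style)."""
    assert embed_dim % 2 == 0, "embed_dim must be even"
    # One frequency per sin/cos pair: omega_i = 1 / 10000^(i / (embed_dim / 2)).
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    pos = pos.reshape(-1)  # (N, M) -> (N * M,)
    angles = np.outer(pos, omega)  # (N * M, embed_dim // 2)
    # The first half of each row holds sines, the second half cosines.
    emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
    return emb.astype(np.float32)


positions = np.arange(6, dtype=np.float32).reshape(2, 3)  # (N=2, M=3)
emb = sincos_1d(8, positions)  # shape (6, 8)
```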

lightly.models.utils.get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool) → NDArray[np.float32]

Generates a 2D sine-cosine positional embedding.

Code follows: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

Parameters
  • embed_dim – Embedding dimension.

  • grid_size – Height and width of the grid.

  • cls_token – If True, a positional embedding for the class token is generated.

Returns

Positional embedding with shape (grid_size * grid_size, embed_dim) or (1 + grid_size * grid_size, embed_dim) if cls_token is True.

lightly.models.utils.get_2d_sine_cosine_positional_embedding(embed_dim: int, grid_size: int, cls_token: bool) → NDArray[np.float32]

Generates a 2D sine-cosine positional embedding.

Code follows: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

Parameters
  • embed_dim – Embedding dimension.

  • grid_size – Height and width of the grid.

  • cls_token – If True, a positional embedding for the class token is generated.

Returns

Positional embedding with shape (grid_size * grid_size, embed_dim) or (1 + grid_size * grid_size, embed_dim) if cls_token is True.
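
A NumPy sketch of the 2D variant, again modeled on the MAE reference code: half of the embedding channels encode one grid coordinate and the other half encode the other, each with the 1D scheme. The names sincos_1d and sincos_2d are hypothetical, and the all-zero class token row mirrors the MAE code rather than a confirmed detail of lightly.

```python
import numpy as np


def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    # 1D sine-cosine embedding of the flattened positions (MAE style).
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    angles = np.outer(pos.reshape(-1), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(embed_dim: int, grid_size: int, cls_token: bool) -> np.ndarray:
    """Hypothetical sketch of a 2D sine-cosine embedding (MAE style)."""
    assert embed_dim % 4 == 0, "embed_dim must be divisible by 4"
    coords = np.arange(grid_size, dtype=np.float32)
    # grid[0] holds the column (x) coordinate, grid[1] the row (y) coordinate.
    grid = np.stack(np.meshgrid(coords, coords), axis=0)  # (2, grid_size, grid_size)
    # Each coordinate is encoded with half of the embedding channels.
    emb = np.concatenate(
        [sincos_1d(embed_dim // 2, grid[0]), sincos_1d(embed_dim // 2, grid[1])], axis=1
    )  # (grid_size * grid_size, embed_dim)
    if cls_token:
        # The MAE reference code prepends an all-zero row for the class token.
        emb = np.concatenate([np.zeros((1, embed_dim)), emb], axis=0)
    return emb.astype(np.float32)


pos_embed = sincos_2d(embed_dim=16, grid_size=4, cls_token=True)  # shape (17, 16)
```

For a 14 × 14 patch grid, for example a 224 × 224 image split into 16 × 16 patches, cls_token=True yields 1 + 196 = 197 rows.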

lightly.models.utils.get_2d_sine_cosine_positional_embedding_from_grid(embed_dim: int, grid: NDArray[np.float32]) → NDArray[np.float32]

Generates a 2D sine-cosine positional embedding from a grid.

Code follows: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

Parameters
  • embed_dim – Embedding dimension.

  • grid – Grid of shape (2, grid_size, grid_size) with x and y coordinates.

Returns

Positional embedding with shape (grid_size * grid_size, embed_dim).

lightly.models.utils.get_at_index(tokens: Tensor, index: Tensor) → Tensor

Selects tokens at index.

Parameters
  • tokens – Token tensor with shape (batch_size, sequence_length, dim).

  • index – Index tensor with shape (batch_size, index_length) where each entry is an index in [0, sequence_length).

Returns

Token tensor with shape (batch_size, index_length, dim) containing the selected tokens.
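
Selecting tokens at an index reduces to expanding the index along the feature dimension (what expand_index_like describes) and gathering along the sequence dimension. A sketch with the hypothetical helper gather_tokens, assuming torch.gather semantics; lightly may implement it differently.

```python
import torch


def gather_tokens(tokens: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: out[b, i] = tokens[b, index[b, i]] for all b and i."""
    # (batch_size, index_length) -> (batch_size, index_length, dim)
    expanded = index.unsqueeze(-1).expand(-1, -1, tokens.shape[-1])
    # Gather along the sequence dimension.
    return torch.gather(tokens, dim=1, index=expanded)


tokens = torch.arange(2 * 5 * 3).reshape(2, 5, 3)  # (batch_size=2, sequence_length=5, dim=3)
index = torch.tensor([[4, 0], [1, 1]])  # (batch_size=2, index_length=2)
selected = gather_tokens(tokens, index)  # shape (2, 2, 3)
```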

lightly.models.utils.get_named_leaf_modules(module: Module) → Dict[str, Module]

Returns all leaf modules of module as a dictionary that maps module names to modules.

lightly.models.utils.get_weight_decay_parameters(modules: Iterable[Module], decay_norm: bool = False, decay_bias: bool = False, norm_layers: Tuple[Type[Module], ...] = (_NormBase, LayerNorm, CrossMapLRN2d, LocalResponseNorm, GroupNorm)) → Tuple[List[Parameter], List[Parameter]]

Splits the parameters of the modules into parameters that should be weight decayed and parameters that should not.

Parameters
  • modules – List of modules to get the parameters from.

  • decay_norm – If True, normalization parameters are decayed.

  • decay_bias – If True, bias parameters are decayed.

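  • norm_layers – Tuple of module classes that are treated as normalization layers. Their parameters are only decayed if decay_norm is True.

Returns

A (params, params_no_weight_decay) tuple, where params contains the parameters that should be weight decayed and params_no_weight_decay the parameters that should not.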
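
Such a split is typically consumed as two optimizer parameter groups. The helper split_weight_decay below is hypothetical and only mirrors the documented behavior under two assumptions: a parameter counts as a bias if its attribute name is exactly "bias", and normalization layers are detected with isinstance against norm_layers. It is not lightly's implementation.

```python
from typing import Iterable, List, Tuple, Type

import torch
from torch import nn
from torch.nn.modules.batchnorm import _NormBase

NORM_LAYERS = (_NormBase, nn.LayerNorm, nn.CrossMapLRN2d, nn.LocalResponseNorm, nn.GroupNorm)


def split_weight_decay(
    modules: Iterable[nn.Module],
    decay_norm: bool = False,
    decay_bias: bool = False,
    norm_layers: Tuple[Type[nn.Module], ...] = NORM_LAYERS,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
    """Hypothetical sketch: split parameters into (decay, no_decay) lists."""
    decay: List[nn.Parameter] = []
    no_decay: List[nn.Parameter] = []
    for module in modules:
        for submodule in module.modules():
            is_norm = isinstance(submodule, norm_layers)
            # recurse=False: each parameter is visited once, by the module owning it.
            for name, param in submodule.named_parameters(recurse=False):
                if is_norm and not decay_norm:
                    no_decay.append(param)
                elif name == "bias" and not decay_bias:
                    no_decay.append(param)
                else:
                    decay.append(param)
    return decay, no_decay


model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
decay, no_decay = split_weight_decay([model])
# Use the split as two optimizer parameter groups; only the first one is decayed.
optimizer = torch.optim.AdamW(
    [{"params": decay}, {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
    weight_decay=0.05,
)
```

Here only the two Linear weights are decayed; the Linear biases and both BatchNorm parameters land in the group with weight_decay=0.0.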

lightly.models.utils.initialize_learnable_positional_embedding(pos_embedding: Parameter) → None

Initializes a learnable positional embedding.

Uses standard initialization for ViT models, see [0].

Parameters

pos_embedding – Positional embedding parameter.

lightly.models.utils.initialize_positional_embedding(pos_embedding: Parameter, strategy: str, num_prefix_tokens: int) → None

Initializes the positional embedding with the given strategy.

Parameters
  • pos_embedding – Positional embedding parameter.

  • strategy – Positional embedding initialization strategy. Valid options are: [‘learn’, ‘sincos’, ‘skip’]. ‘learn’ makes the embedding learnable, ‘sincos’ creates a fixed 2D sine-cosine positional embedding, and ‘skip’ does not initialize the positional embedding.

  • num_prefix_tokens – Number of prefix tokens in the positional embedding. This includes the class token.

Raises

ValueError – If an invalid strategy is provided.

lightly.models.utils.mask_at_index(tokens: Tensor, index: Tensor, mask_token: Tensor) → Tensor

Returns a tensor where the tokens at the given indices are replaced by the mask token.

Parameters
  • tokens – Tokens tensor with shape (batch_size, sequence_length, dim).

  • index – Index tensor with shape (batch_size, index_length).

  • mask_token – Value tensor with shape (1, 1, dim).

Returns

Tokens tensor with shape (batch_size, sequence_length, dim) containing the new values.
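
A sketch of the replacement, assuming a new tensor is returned and using advanced indexing so that the (1, 1, dim) mask token broadcasts over all selected positions. The helper replace_at_index is hypothetical.

```python
import torch


def replace_at_index(
    tokens: torch.Tensor, index: torch.Tensor, mask_token: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: out[b, index[b, i]] = mask_token for all b and i."""
    out = tokens.clone()  # keep the input tensor unchanged
    batch_idx = torch.arange(tokens.shape[0], device=tokens.device).unsqueeze(1)
    # out[batch_idx, index] addresses a (batch_size, index_length, dim) region;
    # the (1, 1, dim) mask token broadcasts over it.
    out[batch_idx, index] = mask_token.to(out.dtype)
    return out


tokens = torch.zeros(2, 4, 3)  # (batch_size=2, sequence_length=4, dim=3)
index = torch.tensor([[0, 2], [3, 1]])  # (batch_size=2, index_length=2)
masked = replace_at_index(tokens, index, torch.ones(1, 1, 3))
```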

lightly.models.utils.mask_bool(tokens: Tensor, mask: Tensor, mask_token: Tensor) → Tensor

Returns a tensor with tokens replaced by the mask tokens in all positions where the mask is True.

Parameters
  • tokens – Tokens tensor with shape (batch_size, sequence_length, dim).

  • mask – Boolean mask tensor with shape (batch_size, sequence_length).

  • mask_token – Mask token with shape (1, 1, dim).

Returns

Tokens tensor with shape (batch_size, sequence_length, dim) where tokens[i, j] is replaced by the mask token if mask[i, j] is True.
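
Boolean masking can be sketched with torch.where and broadcasting: unsqueezing the mask to (batch_size, sequence_length, 1) lets it broadcast against both the tokens and the (1, 1, dim) mask token. The helper replace_where is hypothetical.

```python
import torch


def replace_where(
    tokens: torch.Tensor, mask: torch.Tensor, mask_token: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch of boolean token masking."""
    # (B, N) -> (B, N, 1) broadcasts against (B, N, D) tokens and the (1, 1, D) token.
    return torch.where(mask.unsqueeze(-1), mask_token.to(tokens.dtype), tokens)


tokens = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)
mask = torch.tensor([[True, False, False], [False, False, True]])
out = replace_where(tokens, mask, torch.full((1, 1, 2), -1.0))
```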

lightly.models.utils.most_similar_index(x: Tensor, y: Tensor) → Tensor

For each feature in x, searches for the most similar feature in y and returns its index.

Parameters
  • x – Tensor with shape (B, N, C) containing the features to compare.

  • y – Tensor with shape (B, N, C) containing the features to search for similarity.

Returns

Index with shape (B, N) such that y[i, index[i, j]] is most similar to x[i, j] over all y[i, …].
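
A sketch of the index search, assuming cosine similarity: features are normalized to unit length and compared with a batched dot product. The similarity measure lightly uses is an assumption here, and most_similar is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def most_similar(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch assuming cosine similarity."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    similarity = torch.bmm(x, y.transpose(1, 2))  # (B, N, N): x[i, j] vs y[i, k]
    return similarity.argmax(dim=2)  # (B, N)


x = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # (B=1, N=2, C=2)
y = torch.tensor([[[0.0, 2.0], [3.0, 0.1]]])
index = most_similar(x, y)
```

Here x[0, 0] points along the first axis and matches y[0, 1], while x[0, 1] matches y[0, 0], so index is [[1, 0]].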

lightly.models.utils.nearest_neighbors(input_maps: Tensor, candidate_maps: Tensor, distances: Tensor, num_matches: int) → Tuple[Tensor, Tensor]

Finds the nearest neighbors of the maps in input_maps in candidate_maps.

Parameters
  • input_maps – A tensor of maps for which to find nearest neighbors. It has shape: [batch_size, input_map_size, feature_dimension]

  • candidate_maps – A tensor of maps to search for nearest neighbors. It has shape: [batch_size, candidate_map_size, feature_dimension]

  • distances – A tensor of distances between the maps in input_maps and candidate_maps. It has shape: [batch_size, input_map_size, candidate_map_size]

  • num_matches – Number of nearest neighbors to return. If num_matches is None or -1, all the maps in candidate_maps are considered.

Returns

A tuple of tensors, containing the nearest neighbors in input_maps and candidate_maps. They both have shape: [batch_size, input_map_size, feature_dimension]

lightly.models.utils.normalize_mean_var(x: Tensor, dim: int = -1, eps: float = 1e-06) → Tensor

Normalizes the input tensor to zero mean and unit variance.

Parameters
  • x – Input tensor.

  • dim – Dimension along which to compute mean and standard deviation. Takes last dimension by default.

  • eps – Epsilon value to avoid division by zero.

Returns

Normalized tensor.
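
A sketch of the normalization, assuming the form (x - mean) / sqrt(var + eps) used for normalized pixel targets in the MAE reference code. Whether lightly adds eps to the variance or to the standard deviation, and whether it uses the unbiased variance, are assumptions; normalize_mean_var_sketch is a hypothetical name.

```python
import torch


def normalize_mean_var_sketch(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch: zero mean and unit variance along `dim`."""
    mean = x.mean(dim=dim, keepdim=True)
    var = x.var(dim=dim, keepdim=True)  # unbiased variance, the torch default
    # eps keeps the division finite for constant inputs, where var == 0.
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.randn(4, 16) * 5 + 3
x_norm = normalize_mean_var_sketch(x)
```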

lightly.models.utils.normalize_weight(weight: Parameter, dim: int = 1, keepdim: bool = True)

Normalizes the weight to unit length along the specified dimension.

lightly.models.utils.patchify(images: Tensor, patch_size: int) → Tensor

Converts a batch of input images into patches.

Parameters
  • images – Images tensor with shape (batch_size, channels, height, width).

  • patch_size – Patch size in pixels. Image width and height must be multiples of the patch size.

Returns

Patches tensor with shape (batch_size, num_patches, channels * patch_size ** 2) where num_patches = image_width / patch_size * image_height / patch_size.
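
A sketch of the patch layout, assuming the MAE convention: patches are ordered row by row across the image, and each patch is flattened with channels last, in (row, column, channel) order. patchify_sketch is a hypothetical name.

```python
import torch


def patchify_sketch(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical sketch: (B, C, H, W) -> (B, num_patches, patch_size**2 * C)."""
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # Reorder to (batch, grid row, grid col, pixel row, pixel col, channel).
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(b, gh * gw, patch_size**2 * c)


torch.manual_seed(0)
images = torch.randn(2, 3, 32, 32)
patches = patchify_sketch(images, patch_size=16)  # shape (2, 4, 768)
```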

lightly.models.utils.pool_masked(source: Tensor, mask: Tensor, reduce: str = 'mean', num_cls: Optional[int] = None) → Tensor

Reduces image feature maps (B, C, H, W) or (C, H, W) according to an integer index given by mask (B, H, W) or (H, W).

Parameters
  • source – Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.

  • mask – Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.

  • reduce – The reduction operation to be applied, one of ‘prod’, ‘mean’, ‘amax’ or ‘amin’. Defaults to ‘mean’.

  • num_cls – The number of classes in the possible masks. If None, the number of classes is inferred from the unique elements in mask. This is useful when not all classes are present in the mask.

Returns

A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements in mask or num_cls if specified.
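
A sketch for the batched case using Tensor.scatter_reduce, which accepts the same reduction names ('prod', 'mean', 'amax', 'amin') and is only available in recent PyTorch versions. It assumes the mask labels are the integers 0 to N - 1; the unbatched (C, H, W) case and lightly's handling of other label sets are not covered, and pool_masked_sketch is a hypothetical name.

```python
from typing import Optional

import torch


def pool_masked_sketch(
    source: torch.Tensor, mask: torch.Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> torch.Tensor:
    """Hypothetical sketch: source (B, C, H, W), mask (B, H, W) -> (B, C, N)."""
    b, c, h, w = source.shape
    if num_cls is None:
        # Assumption: labels are contiguous integers starting at 0.
        num_cls = int(mask.max().item()) + 1
    src = source.reshape(b, c, h * w)
    index = mask.reshape(b, 1, h * w).expand(-1, c, -1).long()
    out = torch.zeros(b, c, num_cls, dtype=source.dtype, device=source.device)
    # include_self=False: the zero initial values do not enter the reduction,
    # and classes that never occur keep the value 0.
    return out.scatter_reduce(2, index, src, reduce=reduce, include_self=False)


source = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)  # (B=1, C=2, H=2, W=2)
mask = torch.tensor([[[0, 0], [1, 1]]])  # top row: class 0, bottom row: class 1
pooled = pool_masked_sketch(source, mask)  # shape (1, 2, 2)
```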

lightly.models.utils.prepend_class_token(tokens: Tensor, class_token: Tensor) → Tensor

Prepends class token to tokens.

Parameters
  • tokens – Tokens tensor with shape (batch_size, sequence_length, dim).

  • class_token – Class token with shape (1, 1, dim).

Returns

Tokens tensor with the class token prepended at index 0 in every sequence. The tensor has shape (batch_size, sequence_length + 1, dim).

lightly.models.utils.random_block_mask(size: Tuple[int, int, int], batch_mask_ratio: float = 0.5, min_image_mask_ratio: float = 0.1, max_image_mask_ratio: float = 0.5, min_num_masks_per_block: int = 4, max_num_masks_per_block: Optional[int] = None, min_block_aspect_ratio: float = 0.3, max_block_aspect_ratio: Optional[float] = None, max_attempts_per_block: int = 10, device: Optional[Union[device, str]] = None) → Tensor

Creates a random block mask for a batch of images.

In this context, a block is a rectangle of patches in an image that are masked together. The function generates block masks until the desired number of patches per image is masked. DINOv2 uses a more complex masking strategy that only generates masks for batch_mask_ratio of the images. On top of that, it also masks a different number of patches for every image. This is controlled by the min_image_mask_ratio and max_image_mask_ratio arguments.

Based on the implementation of the block mask in DINOv2 [0]. For details see [1] and [2].

Parameters
  • size – Size of the image batch for which to generate masks. Should be (batch_size, height, width).

  • batch_mask_ratio – Fraction of images in the batch for which to generate block masks. The remaining images are not masked.

  • min_image_mask_ratio – Minimum fraction of the image patches to mask. In practice, a smaller fraction of the patches can be masked due to additional constraints.

  • max_image_mask_ratio – Maximum fraction of the image patches to mask.

  • min_num_masks_per_block – Minimum number of patches to mask per block.

  • max_num_masks_per_block – Maximum number of patches to mask per block.

  • min_block_aspect_ratio – Minimum aspect ratio (height/width) of a masked block.

  • max_block_aspect_ratio – Maximum aspect ratio (height/width) of a masked block.

  • max_attempts_per_block – Maximum number of attempts to find a valid block mask for an image.

  • device – Device on which to create the mask.

Returns

A boolean tensor with shape (batch_size, height, width) where each entry is True if the patch should be masked and False otherwise.

Raises

ValueError – If ‘max_image_mask_ratio’ is less than ‘min_image_mask_ratio’.

lightly.models.utils.random_block_mask_image(size: Tuple[int, int], num_masks: int, min_num_masks_per_block: int = 4, max_num_masks_per_block: Optional[int] = None, min_block_aspect_ratio: float = 0.3, max_block_aspect_ratio: Optional[float] = None, max_attempts_per_block: int = 10, device: Optional[Union[device, str]] = None) → Tensor

Creates a random block mask for a single image.

Parameters
  • size – Size of the image for which to generate a mask. Should be (height, width).

  • num_masks – Number of patches to mask.

  • min_num_masks_per_block – Minimum number of patches to mask per block.

  • max_num_masks_per_block – Maximum number of patches to mask per block.

  • min_block_aspect_ratio – Minimum aspect ratio (height/width) of a masked block.

  • max_block_aspect_ratio – Maximum aspect ratio (height/width) of a masked block.

  • max_attempts_per_block – Maximum number of attempts to find a valid block mask.

  • device – Device on which to create the mask.

Returns

A boolean tensor with shape (height, width) where each entry is True if the patch should be masked and False otherwise.

Raises

ValueError – If ‘max_num_masks_per_block’ is less than ‘min_num_masks_per_block’ or if ‘max_block_aspect_ratio’ is less than ‘min_block_aspect_ratio’.

lightly.models.utils.random_prefix_mask(size: Tuple[int, int], max_prefix_length: int, device: Optional[Union[device, str]] = None) → Tensor

Creates a random prefix mask.

The mask is created by uniformly sampling a prefix length in [0, max_prefix_length] for each sequence in the batch. All tokens with an index greater than or equal to the prefix length are masked.

Parameters
  • size – Size of the token batch for which to generate masks. Should be (batch_size, sequence_length).

  • max_prefix_length – Maximum length of the unmasked prefix.

  • device – Device on which to create the mask.

Returns

A mask tensor with shape (batch_size, sequence_length) where each entry is True if the token should be masked and False otherwise.
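
A sketch of the sampling described above: one prefix length per sequence, drawn uniformly from [0, max_prefix_length], compared against the token positions. random_prefix_mask_sketch is a hypothetical name, and the inclusive upper bound follows the description rather than a confirmed implementation detail.

```python
from typing import Optional, Tuple

import torch


def random_prefix_mask_sketch(
    size: Tuple[int, int], max_prefix_length: int, device: Optional[str] = None
) -> torch.Tensor:
    """Hypothetical sketch: keep a random-length prefix, mask everything after it."""
    batch_size, sequence_length = size
    # One prefix length per sequence, uniformly in [0, max_prefix_length].
    prefix_length = torch.randint(0, max_prefix_length + 1, (batch_size, 1), device=device)
    positions = torch.arange(sequence_length, device=device).expand(batch_size, -1)
    # True = masked: every position at or after the sampled prefix length.
    return positions >= prefix_length


torch.manual_seed(0)
mask = random_prefix_mask_sketch((8, 10), max_prefix_length=4)
```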

lightly.models.utils.random_token_mask(size: Tuple[int, int], mask_ratio: float = 0.6, mask_class_token: bool = False, device: Optional[Union[device, str]] = None) → Tuple[Tensor, Tensor]

Creates random token masks.

Parameters
  • size – Size of the token batch for which to generate masks. Should be (batch_size, sequence_length).

  • mask_ratio – Proportion of tokens to mask.

  • mask_class_token – If False, the class token is never masked. If True, the class token might be masked.

  • device – Device on which to create the index masks.

Returns

A (index_keep, index_mask) tuple where each index is a tensor. index_keep contains the indices of the unmasked tokens and has shape (batch_size, num_keep). index_mask contains the indices of the masked tokens and has shape (batch_size, sequence_length - num_keep). num_keep is equal to sequence_length * (1 - mask_ratio).
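
A sketch of one common way to draw such index masks, used in MAE-style models: sort uniform noise per sequence and split the resulting permutation into kept and masked indices. Keeping the class token at index 0 by giving it the lowest noise value is an assumption about how mask_class_token=False is realized; random_token_mask_sketch is a hypothetical name.

```python
from typing import Optional, Tuple

import torch


def random_token_mask_sketch(
    size: Tuple[int, int],
    mask_ratio: float = 0.6,
    mask_class_token: bool = False,
    device: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of random index masking via argsort of uniform noise."""
    batch_size, sequence_length = size
    num_keep = int(sequence_length * (1 - mask_ratio))
    noise = torch.rand(batch_size, sequence_length, device=device)
    if not mask_class_token:
        # The lowest noise sorts first, so the class token at index 0 is always kept.
        noise[:, 0] = -1.0
        num_keep = max(1, num_keep)
    indices = torch.argsort(noise, dim=1)
    return indices[:, :num_keep], indices[:, num_keep:]


torch.manual_seed(0)
index_keep, index_mask = random_token_mask_sketch((4, 10), mask_ratio=0.6)
```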

lightly.models.utils.repeat_token(token: Tensor, size: Tuple[int, int]) → Tensor

Repeats a token batch_size * sequence_length times, where (batch_size, sequence_length) is given by size.

Parameters
  • token – Token tensor with shape (1, 1, dim).

  • size – (batch_size, sequence_length) tuple.

Returns

Tensor with shape (batch_size, sequence_length, dim) containing copies of the input token.

lightly.models.utils.select_most_similar(x: Tensor, y: Tensor, y_values: Tensor) → Tensor

For each feature in x, searches for the most similar feature in y and returns the corresponding value from y_values.

Parameters
  • x – Tensor with shape (B, N, C).

  • y – Tensor with shape (B, N, C).

  • y_values – Tensor with shape (B, N, D).

Returns

Values with shape (B, N, D) where values[i, j] = y_values[i, k] and y[i, k] is the feature in y[i] that is most similar to x[i, j].
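
Combining the index search with a gather gives a sketch of this selection. Cosine similarity is an assumption here, and select_most_similar_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def select_most_similar_sketch(
    x: torch.Tensor, y: torch.Tensor, y_values: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch assuming cosine similarity."""
    similarity = torch.bmm(F.normalize(x, dim=-1), F.normalize(y, dim=-1).transpose(1, 2))
    index = similarity.argmax(dim=2)  # (B, N): best match in y for each x[i, j]
    # Expand to (B, N, D) and gather the matching rows of y_values.
    index = index.unsqueeze(-1).expand(-1, -1, y_values.shape[-1])
    return torch.gather(y_values, dim=1, index=index)


x = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # (B=1, N=2, C=2)
y = torch.tensor([[[0.0, 2.0], [3.0, 0.1]]])
y_values = torch.tensor([[[10.0, 11.0, 12.0], [20.0, 21.0, 22.0]]])  # (B=1, N=2, D=3)
values = select_most_similar_sketch(x, y, y_values)
```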

lightly.models.utils.set_at_index(tokens: Tensor, index: Tensor, value: Tensor) → Tensor

Copies all values into the input tensor at the given indices.

Parameters
  • tokens – Tokens tensor with shape (batch_size, sequence_length, dim).

  • index – Index tensor with shape (batch_size, index_length).

  • value – Value tensor with shape (batch_size, index_length, dim).

Returns

Tokens tensor with shape (batch_size, sequence_length, dim) containing the new values.

lightly.models.utils.unpatchify(patches: Tensor, patch_size: int, channels: int = 3) → Tensor

Reconstructs images from their patches.

Parameters
  • patches – Patches tensor with shape (batch_size, num_patches, channels * patch_size ** 2).

  • patch_size – The patch size in pixels used to create the patches.

  • channels – The number of channels of the reconstructed images.

Returns

Reconstructed images tensor with shape (batch_size, channels, height, width).
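
A round-trip sketch, assuming MAE-style patch ordering: patches row by row across the image and (row, column, channel) order inside each patch. Because the signature takes no image height or width, the sketch further assumes a square patch grid. The names patchify_sketch and unpatchify_sketch are hypothetical.

```python
import torch


def patchify_sketch(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, patch_size**2 * C), channels last per patch.
    b, c, h, w = images.shape
    x = images.reshape(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    x = torch.einsum("nchpwq->nhwpqc", x)
    return x.reshape(b, -1, patch_size**2 * c)


def unpatchify_sketch(patches: torch.Tensor, patch_size: int, channels: int = 3) -> torch.Tensor:
    """Hypothetical inverse of patchify_sketch for a square patch grid."""
    b, num_patches, _ = patches.shape
    grid = int(round(num_patches**0.5))
    assert grid * grid == num_patches, "expects a square patch grid"
    x = patches.reshape(b, grid, grid, patch_size, patch_size, channels)
    # Move channels back in front and interleave grid and pixel axes.
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(b, channels, grid * patch_size, grid * patch_size)


torch.manual_seed(0)
images = torch.randn(2, 3, 32, 32)
reconstructed = unpatchify_sketch(patchify_sketch(images, 8), 8, channels=3)
```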

lightly.models.utils.update_drop_path_rate(model: VisionTransformer, drop_path_rate: float, mode: str = 'linear') → None

Updates the drop path rate in a TIMM VisionTransformer model.

Parameters
  • model – TIMM VisionTransformer model.

  • drop_path_rate – Maximum drop path rate.

  • mode – Drop path rate update mode. Can be “linear” or “uniform”. Linear increases the drop path rate from 0 to drop_path_rate over the depth of the model. Uniform sets the drop path rate to drop_path_rate for all blocks.

Raises

ValueError – If an unknown mode is provided.
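
The two modes can be illustrated by the per-block rates they produce. This sketch only computes the schedule for a model with depth blocks and does not modify a TIMM model; evenly spaced rates that reach drop_path_rate at the last block are an assumption based on the description of the linear mode, and drop_path_schedule is a hypothetical name.

```python
from typing import List


def drop_path_schedule(depth: int, drop_path_rate: float, mode: str = "linear") -> List[float]:
    """Hypothetical sketch of the per-block drop path rates for `depth` blocks."""
    if mode == "linear":
        # Evenly spaced from 0.0 (first block) to drop_path_rate (last block).
        step = drop_path_rate / max(depth - 1, 1)
        return [i * step for i in range(depth)]
    if mode == "uniform":
        # The same rate for every block.
        return [drop_path_rate] * depth
    raise ValueError(f"Unknown mode: {mode}")
```

For example, with 12 blocks and drop_path_rate=0.1, the linear mode assigns 0.0 to the first block and 0.1 to the last, while the uniform mode assigns 0.1 to every block.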

lightly.models.utils.update_momentum(model: Module, model_ema: Module, m: float)

Updates the parameters of model_ema with an exponential moving average of the parameters of model.

Momentum encoders are a crucial component for models such as MoCo or BYOL.

Parameters
  • model – The current model.

  • model_ema – The model with exponential moving average (EMA) parameters.

  • m – The momentum factor, between 0 and 1.

Examples

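>>> import copy
>>> from torchvision.models import resnet18
>>> from lightly.models.modules import MoCoProjectionHead
>>> from lightly.models.utils import update_momentum
>>> backbone = resnet18()
>>> projection_head = MoCoProjectionHead()
>>> backbone_momentum = copy.deepcopy(backbone)
>>> projection_head_momentum = copy.deepcopy(projection_head)
>>>
>>> # update momentum
>>> update_momentum(backbone, backbone_momentum, m=0.999)
>>> update_momentum(projection_head, projection_head_momentum, m=0.999)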
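
The update is an exponential moving average: each EMA parameter becomes m times its current value plus (1 - m) times the corresponding parameter of model. A sketch with the hypothetical helper ema_update, assuming the parameters of both models are paired in order and that buffers are not updated.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(model: nn.Module, model_ema: nn.Module, m: float) -> None:
    """Hypothetical sketch: model_ema <- m * model_ema + (1 - m) * model."""
    for param_ema, param in zip(model_ema.parameters(), model.parameters()):
        param_ema.mul_(m).add_(param, alpha=1.0 - m)


model = nn.Linear(4, 2)
model_ema = copy.deepcopy(model)
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)  # pretend an optimizer step moved the online model
ema_update(model, model_ema, m=0.9)
```

With m=0.999, the EMA parameters move only 0.1% of the way toward the online parameters per call, so the momentum model changes slowly.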