SMoG
Example implementation of the Synchronous Momentum Grouping (SMoG) paper (https://arxiv.org/abs/2207.06167). SMoG follows the contrastive learning framework but replaces the contrastive unit, moving from individual instances to groups, thereby mimicking clustering-based methods. To achieve this, the authors propose a momentum grouping scheme that performs feature grouping and representation learning synchronously.
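At its core, each encoded feature is assigned to its nearest group prototype, and the assigned prototypes are updated on the fly with a momentum rule instead of being recomputed from scratch. A minimal sketch of this idea with toy tensors (illustrative only, not the library API):

import torch
import torch.nn.functional as F

features = F.normalize(torch.randn(4, 8), dim=1)  # 4 encoded features of dim 8
groups = F.normalize(torch.randn(3, 8), dim=1)  # 3 group prototypes

# group assignment: nearest prototype by cosine similarity
assignments = (features @ groups.t()).argmax(dim=1)

# synchronous momentum update of the assigned prototypes
beta = 0.99  # same momentum value the example below uses
for i, g in enumerate(assignments):
    groups[g] = F.normalize(beta * groups[g] + (1 - beta) * features[i], dim=0)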
This example can be run from the command line with:
python lightly/examples/pytorch/smog.py
# This example requires the following dependencies to be installed:
# pip install lightly scikit-learn
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import torch
import torchvision
from sklearn.cluster import KMeans
from torch import nn
from lightly.models import utils
from lightly.models.modules.heads import (
SMoGPredictionHead,
SMoGProjectionHead,
SMoGPrototypes,
)
from lightly.models.modules.memory_bank import MemoryBankModule
from lightly.transforms.smog_transform import SMoGTransform
class SMoGModel(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone
self.projection_head = SMoGProjectionHead(512, 2048, 128)
self.prediction_head = SMoGPredictionHead(128, 2048, 128)
self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone_momentum)
utils.deactivate_requires_grad(self.projection_head_momentum)
self.n_groups = 300
self.smog = SMoGPrototypes(
group_features=torch.rand(self.n_groups, 128), beta=0.99
)
def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:
# clusters the features using sklearn
# (note: faiss is probably more efficient)
features = features.cpu().numpy()
kmeans = KMeans(self.n_groups).fit(features)
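        # sketch of the faiss variant mentioned above (assumes faiss-cpu is
        # installed; faiss.Kmeans trains on float32 arrays):
        # import faiss
        # kmeans = faiss.Kmeans(d=features.shape[1], k=self.n_groups)
        # kmeans.train(features.astype("float32"))
        # centroids = kmeans.centroids  # shape (n_groups, d)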
clustered = torch.from_numpy(kmeans.cluster_centers_).float()
clustered = torch.nn.functional.normalize(clustered, dim=1)
return clustered
def reset_group_features(self, memory_bank):
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
features = memory_bank.bank
group_features = self._cluster_features(features.t())
self.smog.set_group_features(group_features)
def reset_momentum_weights(self):
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone_momentum)
utils.deactivate_requires_grad(self.projection_head_momentum)
def forward(self, x):
features = self.backbone(x).flatten(start_dim=1)
encoded = self.projection_head(features)
predicted = self.prediction_head(encoded)
return encoded, predicted
def forward_momentum(self, x):
features = self.backbone_momentum(x).flatten(start_dim=1)
encoded = self.projection_head_momentum(features)
return encoded
batch_size = 256
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SMoGModel(backbone)
# the memory bank is sized to hold all features seen between two group feature
# resets (one reset every 300 iterations)
memory_bank_size = 300 * batch_size
memory_bank = MemoryBankModule(size=(memory_bank_size, 128))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
transform = SMoGTransform(
crop_sizes=(32, 32),
crop_counts=(1, 1),
gaussian_blur_probs=(0.0, 0.0),
crop_min_scales=(0.2, 0.2),
crop_max_scales=(1.0, 1.0),
)
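# SMoGTransform returns two augmented views per image; with crop_counts=(1, 1)
# each view is a single crop, so a batch unpacks as (x0, x1) in the loop below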
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=8,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6
)
global_step = 0
print("Starting Training")
for epoch in range(10):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
(x0, x1) = batch[0]
if batch_idx % 2:
# swap batches every second iteration
x1, x0 = x0, x1
x0 = x0.to(device)
x1 = x1.to(device)
if global_step > 0 and global_step % 300 == 0:
# reset group features and weights every 300 iterations
model.reset_group_features(memory_bank=memory_bank)
model.reset_momentum_weights()
else:
# update momentum
utils.update_momentum(model.backbone, model.backbone_momentum, 0.99)
utils.update_momentum(
model.projection_head, model.projection_head_momentum, 0.99
)
x0_encoded, x0_predicted = model(x0)
x1_encoded = model.forward_momentum(x1)
# update group features and get group assignments
assignments = model.smog.assign_groups(x1_encoded)
group_features = model.smog.get_updated_group_features(x0_encoded)
logits = model.smog(x0_predicted, group_features, temperature=0.1)
model.smog.set_group_features(group_features)
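        # group-level objective: the prediction from view 0 is classified into
        # the group that the momentum branch assigned to view 1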
loss = criterion(logits, assignments)
# use memory bank to periodically reset the group features with k-means
memory_bank(x0_encoded, update=True)
loss.backward()
optimizer.step()
optimizer.zero_grad()
global_step += 1
total_loss += loss.detach()
avg_loss = total_loss / len(dataloader)
print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
The same example is also available as a PyTorch Lightning version and can be run from the command line with:
python lightly/examples/pytorch_lightning/smog.py
# This example requires the following dependencies to be installed:
# pip install lightly scikit-learn
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.
import copy
import pytorch_lightning as pl
import torch
import torchvision
from sklearn.cluster import KMeans
from torch import nn
from lightly import loss, models
from lightly.models import utils
from lightly.models.modules import heads
from lightly.transforms.smog_transform import SMoGTransform
class SMoGModel(pl.LightningModule):
def __init__(self):
super().__init__()
# create a ResNet backbone and remove the classification head
resnet = models.ResNetGenerator("resnet-18")
self.backbone = nn.Sequential(
*list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1)
)
# create a model based on ResNet
self.projection_head = heads.SMoGProjectionHead(512, 2048, 128)
self.prediction_head = heads.SMoGPredictionHead(128, 2048, 128)
self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone_momentum)
utils.deactivate_requires_grad(self.projection_head_momentum)
# smog
self.n_groups = 300
memory_bank_size = 10000
self.memory_bank = loss.memory_bank.MemoryBankModule(size=memory_bank_size)
# create our loss
group_features = torch.nn.functional.normalize(
torch.rand(self.n_groups, 128), dim=1
).to(self.device)
self.smog = heads.SMoGPrototypes(group_features=group_features, beta=0.99)
self.criterion = nn.CrossEntropyLoss()
def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:
features = features.cpu().numpy()
kmeans = KMeans(self.n_groups).fit(features)
clustered = torch.from_numpy(kmeans.cluster_centers_).float()
clustered = torch.nn.functional.normalize(clustered, dim=1)
return clustered
def _reset_group_features(self):
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
features = self.memory_bank.bank
group_features = self._cluster_features(features.t())
self.smog.set_group_features(group_features)
def _reset_momentum_weights(self):
# see https://arxiv.org/pdf/2207.06167.pdf Table 7b)
self.backbone_momentum = copy.deepcopy(self.backbone)
self.projection_head_momentum = copy.deepcopy(self.projection_head)
utils.deactivate_requires_grad(self.backbone_momentum)
utils.deactivate_requires_grad(self.projection_head_momentum)
def training_step(self, batch, batch_idx):
if self.global_step > 0 and self.global_step % 300 == 0:
# reset group features and weights every 300 iterations
self._reset_group_features()
self._reset_momentum_weights()
else:
# update momentum
utils.update_momentum(self.backbone, self.backbone_momentum, 0.99)
utils.update_momentum(
self.projection_head, self.projection_head_momentum, 0.99
)
(x0, x1) = batch[0]
if batch_idx % 2:
# swap batches every second iteration
x0, x1 = x1, x0
x0_features = self.backbone(x0).flatten(start_dim=1)
x0_encoded = self.projection_head(x0_features)
x0_predicted = self.prediction_head(x0_encoded)
x1_features = self.backbone_momentum(x1).flatten(start_dim=1)
x1_encoded = self.projection_head_momentum(x1_features)
# update group features and get group assignments
assignments = self.smog.assign_groups(x1_encoded)
group_features = self.smog.get_updated_group_features(x0_encoded)
logits = self.smog(x0_predicted, group_features, temperature=0.1)
self.smog.set_group_features(group_features)
loss = self.criterion(logits, assignments)
# use memory bank to periodically reset the group features with k-means
self.memory_bank(x0_encoded, update=True)
return loss
def configure_optimizers(self):
params = (
list(self.backbone.parameters())
+ list(self.projection_head.parameters())
+ list(self.prediction_head.parameters())
)
optim = torch.optim.SGD(
params,
lr=0.01,
momentum=0.9,
weight_decay=1e-6,
)
return optim
model = SMoGModel()
transform = SMoGTransform(
crop_sizes=(32, 32),
crop_counts=(1, 1),
gaussian_blur_probs=(0.0, 0.0),
crop_min_scales=(0.2, 0.2),
crop_max_scales=(1.0, 1.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=256,
shuffle=True,
drop_last=True,
num_workers=8,
)
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
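After training, the backbone can be used to embed images. A minimal sketch with a random batch (the input size matches the 32x32 crops used above):

# embed a batch of (random, illustrative) images with the trained backbone
model.eval()
with torch.no_grad():
    images = torch.rand(4, 3, 32, 32, device=model.device)
    embeddings = model.backbone(images).flatten(start_dim=1)
print(embeddings.shape)  # torch.Size([4, 512])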