Torchvision
This page describes how to use Torchvision models with LightlyTrain.
Pretrain and Fine-tune a Torchvision Model
Pretrain
Pretraining Torchvision models with LightlyTrain is straightforward. Below is a minimal pretraining script, using torchvision/resnet18 as an example:
import lightly_train
if __name__ == "__main__":
    lightly_train.train(
        out="out/my_experiment",        # Output directory.
        data="my_data_dir",             # Directory with images.
        model="torchvision/resnet18",   # Pass the Torchvision model.
    )
Alternatively, you can pass a Torchvision model instance directly:
from torchvision.models import resnet18
import lightly_train
if __name__ == "__main__":
    model = resnet18()  # Create the Torchvision model (randomly initialized).
    lightly_train.train(
        out="out/my_experiment",  # Output directory.
        data="my_data_dir",       # Directory with images.
        model=model,              # Pass the Torchvision model.
    )
The same pretraining run can also be started from the command line:
lightly-train train out="out/my_experiment" data="my_data_dir" model="torchvision/resnet18"
Fine-tune
After pretraining, you can load the exported weights into a Torchvision model for fine-tuning:
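import torch
from torchvision.models import resnet18
model = resnet18()
# Load the weights exported by LightlyTrain. weights_only is an argument of
# torch.load, not of load_state_dict.
state_dict = torch.load("out/my_experiment/exported_models/exported_last.pt", weights_only=True)
model.load_state_dict(state_dict)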
Supported Models
The following Torchvision models are supported:
ResNet
torchvision/resnet18
torchvision/resnet34
torchvision/resnet50
torchvision/resnet101
torchvision/resnet152
ConvNeXt
torchvision/convnext_base
torchvision/convnext_large
torchvision/convnext_small
torchvision/convnext_tiny
ShuffleNetV2
torchvision/shufflenet_v2_x0_5
torchvision/shufflenet_v2_x1_0
torchvision/shufflenet_v2_x1_5
torchvision/shufflenet_v2_x2_0