Torchvision
This page describes how to use Torchvision models with LightlyTrain.
Pretrain and Fine-tune a Torchvision Model
Pretrain
Pretraining Torchvision models with LightlyTrain is straightforward. Below is a minimal pretraining script that uses torchvision/resnet18 as an example:

```python
import lightly_train

if __name__ == "__main__":
    lightly_train.train(
        out="out/my_experiment",  # Output directory.
        data="my_data_dir",  # Directory with images.
        model="torchvision/resnet18",  # Pass the Torchvision model.
    )
```

Alternatively, pass a Torchvision model instance directly:

```python
from torchvision.models import resnet18

import lightly_train

if __name__ == "__main__":
    model = resnet18()  # Load the Torchvision model.
    lightly_train.train(
        out="out/my_experiment",  # Output directory.
        data="my_data_dir",  # Directory with images.
        model=model,  # Pass the Torchvision model.
    )
```

The same pretraining run can also be started from the command line:

```bash
lightly-train train out="out/my_experiment" data="my_data_dir" model="torchvision/resnet18"
```

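The keyword arguments of lightly_train.train map one-to-one onto the key=value arguments of the command above. A minimal sketch of that correspondence follows; it assumes the mapping holds for any keyword argument (only out, data, and model are shown on this page), and to_cli_args is a hypothetical helper, not part of LightlyTrain.

```python
import shlex
from typing import List


def to_cli_args(**kwargs: str) -> List[str]:
    """Hypothetical helper: turn train() keyword arguments into key=value CLI arguments."""
    return ["lightly-train", "train"] + [f"{key}={value}" for key, value in kwargs.items()]


args = to_cli_args(
    out="out/my_experiment",
    data="my_data_dir",
    model="torchvision/resnet18",
)
print(shlex.join(args))
# lightly-train train out=out/my_experiment data=my_data_dir model=torchvision/resnet18
```

Passing the resulting list to subprocess.run(args, check=True) would start the run, provided the lightly-train executable is installed and on the PATH.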
Fine-tune
After pretraining, you can load the exported model for fine-tuning with Torchvision:
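
```python
import torch
from torchvision.models import resnet18

# Instantiate the same architecture that was pretrained.
model = resnet18()
# Load the weights exported by LightlyTrain after pretraining.
state_dict = torch.load("out/my_experiment/exported_models/exported_last.pt", weights_only=True)
model.load_state_dict(state_dict)
```
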
Supported Models
The following Torchvision models are supported:

ResNet
- torchvision/resnet18
- torchvision/resnet34
- torchvision/resnet50
- torchvision/resnet101
- torchvision/resnet152

ConvNeXt
- torchvision/convnext_base
- torchvision/convnext_large
- torchvision/convnext_small
- torchvision/convnext_tiny

ShuffleNetV2
- torchvision/shufflenet_v2_x0_5
- torchvision/shufflenet_v2_x1_0
- torchvision/shufflenet_v2_x1_5
- torchvision/shufflenet_v2_x2_0