Self-Supervision with FastAI
A tutorial of rotation-based self-supervision using FastAI2 & PyTorch!
- Introduction
- Experiment Layout
- FastAI Vision Model Creation Methods
- PyTorch Rotation/Classification Self-Supervised Dataset
- Rotation Prediction Data
- FastAI Vision Learner [Rotation]
- Original Classification Data
- FastAI Vision Learner [Transfer-Classification]
- FastAI Vision Learner [From Scratch-Classification]
- Conclusion
This notebook is an introduction to self-supervised learning. In short, self-supervised learning has 2 components:
- Pretrain on a pretext task, where the labels can come from the data itself!
- Transfer the features, and train on the actual classification labels!
"What if we can get labels for free for unlabelled data and train unsupervised dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form to predict only a subset of information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as self-supervised learning." - Lilian Weng
Using FastAI2, we'll use rotation as a pretext task for learning representations/features of our data.
Here are some great overviews of self-supervised learning that I've come across:
In this notebook, we will be using the MNIST dataset.
Also check out ImageWang from FastAI themselves! It's a dataset designed for self-supervision tasks!
Train a model on a rotation prediction task.
- We will use all the training data for rotation prediction.
- Input: A rotated image.
- Target/Label: Classify how many degrees the image was rotated.
- Our model should learn useful features that can transfer well for a classification task.
- (The model should learn what digits look like in order to be able to successfully predict the amount of rotation).
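The pretext labels in this first experiment cost nothing to produce: the label is simply the rotation we applied. Below is a minimal NumPy sketch of that idea. It's a simplification, not the notebook's dataset: it uses np.rot90, so it only supports the four multiples of 90 degrees, whereas the notebook uses eight 45-degree steps via PIL. The helper name make_rotation_pairs is hypothetical.

```python
import numpy as np

def make_rotation_pairs(images, rng):
    """Turn unlabeled images into (rotated_image, rotation_label) pairs.

    Hypothetical helper: label k means the image was rotated k * 90 degrees
    clockwise. np.rot90 rotates counter-clockwise for positive k, so we pass -k.
    """
    pairs = []
    for img in images:
        k = int(rng.integers(0, 4))    # pick one of 4 rotation classes
        rotated = np.rot90(img, k=-k)  # negative k -> clockwise rotation
        pairs.append((rotated, k))
    return pairs

rng = np.random.default_rng(0)
unlabeled = [np.arange(9).reshape(3, 3) for _ in range(5)]
pairs = make_rotation_pairs(unlabeled, rng)  # labels came "for free"
```

No human annotation is involved: undoing the rotation with np.rot90(rotated, k=label) recovers the original image, which is exactly why the label is trustworthy.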
Transfer our rotation pretraining features to solve the classification task with far fewer labels, < 1% of the original data.
- Input: A normal image.
- Target/Label: The image's original categorical label.
- Classification accuracy should be decent, even when using only < 1% of the original data.
Train a classifier from scratch on the same amount of data used in experiment 2.
- Input: A normal image.
- Target/Label: The image's original categorical label.
- Classification accuracy should be low (no transfer learning & too little labeled data!)
- Model may overfit.
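To make "< 1% of the original data" concrete: MNIST's standard training split has 60,000 images, so any fraction below 0.01 leaves fewer than 600 labeled examples. The sketch below mirrors the int(len * pct) slicing used later by the dataset class; the pct values are illustrative only, not the ones the notebook necessarily uses.

```python
def labeled_budget(n_total, pct):
    # Same truncating arithmetic as the dataset's slice index: int(len * pct)
    return int(n_total * pct)

MNIST_TRAIN_SIZE = 60_000  # size of the standard MNIST training split

for pct in (0.01, 0.005, 0.001):
    print(f"pct={pct}: {labeled_budget(MNIST_TRAIN_SIZE, pct)} labeled images")
```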
# Uncomment and run the below line to get a fresh install of fastai, if needed
# !pip install fastai --upgrade
The model-creation function below, simple_arch, takes in one argument, pretrained. This allows FastAI to pass pretrained=True or pretrained=False when creating the model body! Below are some use cases for each:
- pretrained=False: For training a new model on our rotation prediction task.
- pretrained=True: For transferring the learnt features from our rotation task pretraining to solve a classification task.
- pretrained=False: For training a new model from scratch on the main classification task (no transfer learning).
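One way to picture what pretrained=True buys us in experiment 2: keep every body parameter learned on the rotation task, and discard the 8-way rotation head in favor of a freshly initialized 10-way digit head. The sketch below uses plain dicts of NumPy arrays as stand-ins for PyTorch state dicts; the key names (body.*, head.*) and the helper transfer_body are hypothetical, and real PyTorch code would go through state_dict() / load_state_dict() instead.

```python
import numpy as np

def transfer_body(pretrained_params, n_out, nf, rng):
    """Copy body weights and re-initialize a linear head with n_out outputs.

    Hypothetical stand-in for loading a state dict: keys starting with
    'body.' are reused, the 'head.' entries are replaced with fresh ones.
    """
    params = {k: v.copy() for k, v in pretrained_params.items() if k.startswith("body.")}
    params["head.weight"] = rng.normal(0.0, 0.01, size=(n_out, nf))  # fresh head
    params["head.bias"] = np.zeros(n_out)
    return params

rng = np.random.default_rng(0)
rotation_model = {
    "body.conv3.weight": rng.normal(size=(32, 16, 3, 3)),
    "head.weight": rng.normal(size=(8, 32)),  # 8 rotation classes
    "head.bias": np.zeros(8),
}
digit_model = transfer_body(rotation_model, n_out=10, nf=32, rng=rng)
print(digit_model["head.weight"].shape)  # (10, 32)
```

Training from scratch (the third experiment) corresponds to skipping the copy entirely and initializing the body randomly as well.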
from import *
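def simple_arch(pretrained=False):
    # Note that FastAI will automatically cut at pooling layer for the body!
    model = nn.Sequential(
        nn.Conv2d(1, 4, 3, 1),
        nn.ReLU(),
        nn.Conv2d(4, 16, 3, 1),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, 1),
        nn.ReLU(),
        # Pooling layer for create_body to cut at (it is not kept in the body)
        nn.AdaptiveAvgPool2d(1),
    )
    if (pretrained):
        print("Loading pretrained model...")
        pretrained_weights = torch.load(save_path/'')
        model.load_state_dict(pretrained_weights)
    return model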
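The three 3x3, stride-1, unpadded convolutions in simple_arch each shrink the spatial size by 2, and the last one outputs 32 channels, which is where nf=32 for the head comes from. The sketch below applies the standard output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1, to the 32x32 single-channel inputs that the dataset produces.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Standard output-size formula for a convolution (no dilation)
    return (size + 2 * padding - kernel) // stride + 1

# (in_channels, out_channels) of the three Conv2d layers in simple_arch
layers = [(1, 4), (4, 16), (16, 32)]
size, channels = 32, 1  # the dataset resizes MNIST digits to 32x32, 1 channel
for c_in, c_out in layers:
    size, channels = conv_out(size), c_out
    print(f"{c_in:>2} -> {c_out:>2} channels, {size}x{size}")
# Final feature map: 32 channels of 26x26, pooled by the head -> nf = 32
```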
The code snippets below are examples of how FastAI creates CNNs. Every model has a body and a head.
body = create_body(arch=simple_arch, pretrained=False)
head = create_head(nf=32, n_out=8, lin_ftrs=[])
# Note that FastAI automatically determines nf for the head!
model = create_cnn_model(arch=simple_arch, pretrained=False, n_out=8, lin_ftrs=[])
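The "cut at pooling layer" behavior can be illustrated without fastai itself. As far as I understand fastai v2's create_body, when no explicit cut is given it finds the last child module that contains a pooling layer and keeps only the children before it. Below is a pure-Python imitation over a hypothetical list of layer names; the name-matching rule is a simplification of fastai's actual pooling-type check, and the ValueError is this sketch's own choice.

```python
def cut_at_last_pool(layer_names):
    """Return the layers kept as the 'body': everything before the last pooling layer.

    Simplified imitation: a layer counts as pooling if 'Pool' appears in its name.
    """
    for i in range(len(layer_names) - 1, -1, -1):
        if "Pool" in layer_names[i]:
            return layer_names[:i]
    raise ValueError("no pooling layer found to cut at")

arch = ["Conv2d", "ReLU", "Conv2d", "ReLU", "Conv2d", "ReLU", "AdaptiveAvgPool2d"]
print(cut_at_last_pool(arch))
```

The head then supplies its own pooling and linear layers on top of this body.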
import torchvision
tensorToImage = torchvision.transforms.ToPILImage()
imageToTensor = torchvision.transforms.ToTensor()
# Uncomment and run the below lines if torchvision has trouble downloading MNIST (in the next cell)
# !wget -P data/MNIST/raw/
# !wget -P data/MNIST/raw/
# !wget -P data/MNIST/raw/
# !wget -P data/MNIST/raw/
torchvision.datasets.MNIST('data/', download=True)
Below we define a dataset. Here's the docstring:
A Dataset for Rotation-based Self-Supervision! Images are rotated clockwise.
- file: MNIST processed .pt file.
- pct: fraction of data to use (1.0 = all of it).
- classification: False=Use rotation labels. True=Use original classification labels.
class Custom_Dataset_MNIST():
    '''
    A Dataset for Rotation-based Self-Supervision! Images are rotated clockwise.
    - file - MNIST processed .pt file.
    - pct - fraction of data to use (1.0 = all of it)
    - classification - False=Use rotation labels. True=Use original classification labels.
    '''
    def __init__(self, file, pct, classification):
        data = torch.load(file)
        self.imgs = data[0]
        self.labels = data[1]
        self.pct = pct
        self.classification = classification
        slice_idx = int(len(self.imgs)*self.pct)
        self.imgs = self.imgs[:slice_idx]
        self.labels = self.labels[:slice_idx]  # keep labels aligned with images

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx].unsqueeze(0)
        img = tensorToImage(img)
        img = img.resize((32, 32), resample=1)
        img = imageToTensor(img)
        if (not self.classification):
            # 8 classes for rotation
            degrees = [0, 45, 90, 135, 180, 225, 270, 315]
            rand_choice = random.randint(0, len(degrees)-1)
            img = tensorToImage(img)
            # PIL's rotate() turns counter-clockwise, so negate the angle to rotate clockwise
            img = img.rotate(-degrees[rand_choice])
            img = imageToTensor(img)
            return img, torch.tensor(rand_choice).long()
        return img, self.labels[idx]

    def show_batch(self, n=3):
        fig, axs = plt.subplots(n, n)
        for i in range(n):
            for j in range(n):
                rand_idx = random.randint(0, len(self)-1)
                img, label = self.__getitem__(rand_idx)
                axs[i, j].imshow(tensorToImage(img), cmap='gray')
                if self.classification:
                    axs[i, j].set_title('Label: {0} (Digit #{1})'.format(label.item(), label.item()))
                else:
                    axs[i, j].set_title('Label: {0} ({1} Degrees)'.format(label.item(), label.item()*45))
                axs[i, j].axis('off')
train_ds = Custom_Dataset_MNIST('data/MNIST/processed/', pct=1.0, classification=False)
valid_ds = Custom_Dataset_MNIST('data/MNIST/processed/', pct=1.0, classification=False)
print('{0} Training Samples | {1} Validation Samples'.format(len(train_ds), len(valid_ds)))
from import DataLoaders
dls = DataLoaders.from_dsets(train_ds, valid_ds).cuda()
# Override the show_batch function of dls to the one used in our dataset!
dls.show_batch = train_ds.show_batch
# We have 8 classes! [0, 1, 2, 3, 4, 5, 6, 7] that correspond to the [0, 45, 90, 135, 180, 225, 270, 315] degrees of rotation.
dls.c = 8
rotation_head = create_head(nf=32, n_out=8, lin_ftrs=[])
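The 8 rotation classes (dls.c = 8) and the rotation angles are related by a fixed 45-degree step, the same relation show_batch uses when it prints label.item()*45. A small sketch of the mapping in both directions; the helper names are hypothetical.

```python
STEP = 45                # degrees between neighbouring rotation classes
N_CLASSES = 360 // STEP  # 8 classes, matching dls.c = 8

def label_to_degrees(label):
    return label * STEP

def degrees_to_label(degrees):
    # Normalize first so that e.g. -45 or 405 map onto the 0..7 range
    return (degrees % 360) // STEP

print([label_to_degrees(k) for k in range(N_CLASSES)])
# [0, 45, 90, 135, 180, 225, 270, 315]
```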
We'll track top-2 accuracy along with regular (top-1) accuracy, because there are hard cases where it's understandable why our model got it wrong. For example: '0' rotated 90 or 270 degrees, or '1' rotated 0 or 180 degrees. (They can look the same!)
# - A zero rotated 90 or 270 degrees?
# - A one rotated 0 or 180 degrees?
# etc :P
top_2_accuracy = lambda inp, targ: top_k_accuracy(inp, targ, k=2)
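Top-2 accuracy counts a prediction as correct if the true class is among the two highest-scoring classes, which forgives exactly the "looks the same" confusions listed above. fastai's top_k_accuracy handles this for us; the NumPy sketch below only spells out the computation on a toy batch with made-up logits.

```python
import numpy as np

def top_k_accuracy_np(logits, targets, k=2):
    # Indices of the k largest logits per row (order within the top-k doesn't matter)
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == targets[:, None]).any(axis=1)
    return hits.mean()

# Toy batch: 3 samples, 4 classes (made-up numbers)
logits = np.array([
    [0.1, 2.0, 1.5, 0.0],  # top-1 = 1, top-2 = {1, 2}
    [3.0, 0.2, 0.1, 2.9],  # top-1 = 0, top-2 = {0, 3}
    [0.5, 0.4, 0.3, 0.2],  # top-1 = 0, top-2 = {0, 1}
])
targets = np.array([2, 3, 2])

print(top_k_accuracy_np(logits, targets, k=1))  # 0.0
print(top_k_accuracy_np(logits, targets, k=2))  # 2 of 3 correct -> 0.666...
```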
Here, we train a model on the rotation prediction task!
# Note: to set a value for lin_ftrs, we use the config defined above.
learner = cnn_learner(dls,
metrics=[accuracy, top_2_accuracy])
learner.fit_one_cycle(5, lr_max=3e-2)
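fit_one_cycle(5, lr_max=3e-2) does not train at a constant 3e-2: the learning rate warms up to lr_max and then anneals toward zero. The sketch below approximates that shape with two cosine segments, assuming fastai's default arguments div=25, div_final=1e5 and pct_start=0.25. It illustrates the schedule only; it is not fastai's code, and momentum scheduling is omitted.

```python
import math

def cos_interp(start, end, pos):
    # Cosine interpolation from start (pos=0) to end (pos=1)
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max=3e-2, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] (assumed fastai defaults)."""
    if pos < pct_start:
        # Warm-up: lr_max / div -> lr_max
        return cos_interp(lr_max / div, lr_max, pos / pct_start)
    # Annealing: lr_max -> lr_max / div_final
    return cos_interp(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"progress {pos:.2f}: lr = {one_cycle_lr(pos):.2e}")
```

Starting low and ending very low is what lets a relatively large peak like 3e-2 train stably.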