Self-Supervision with FastAI
A tutorial on rotation-based self-supervision using FastAI2 & PyTorch!
- Introduction
- Experiment Layout
- FastAI Vision Model Creation Methods
- PyTorch Rotation/Classification Self-Supervised Dataset
- Rotation Prediction Data
- FastAI Vision Learner [Rotation]
- Original Classification Data
- FastAI Vision Learner [Transfer-Classification]
- FastAI Vision Learner [From Scratch-Classification]
- Conclusion
Introduction
This notebook is an introduction to self-supervised learning. In short, self-supervised learning has 2 components:
- Pretrain on a pretext task, where the labels can come from the data itself!
- Transfer the features, and train on the actual classification labels!
"What if we can get labels for free for unlabelled data and train unsupervised dataset in a supervised manner? We can achieve this by framing a supervised learning task in a special form to predict only a subset of information using the rest. In this way, all the information needed, both inputs and labels, has been provided. This is known as self-supervised learning." - Lilian Weng
Using FastAI2, we'll use rotation as a pretext task for learning representations/features of our data.
Lilian Weng's blog post, quoted above, is a great overview of self-supervised learning.
In this notebook, we will be using the MNIST dataset.
Also check out ImageWang from FastAI themselves! It's a dataset designed for self-supervision tasks!
Experiment Layout
1. Train a model on a rotation prediction task.
  - We will use all the training data for rotation prediction.
  - Input: A rotated image.
  - Target/Label: How many degrees the image was rotated (see the sketch after this list).
  - Our model should learn useful features that can transfer well to a classification task.
  - (The model should learn what digits look like in order to successfully predict the amount of rotation.)
2. Transfer our rotation pretraining features to solve the classification task with far fewer labels, < 1% of the original data.
  - Input: A normal image.
  - Target/Label: The image's original categorical label.
  - Classification accuracy should be decent, even when using < 1% of the original data.
3. Train a classifier from scratch on the same amount of data used in experiment 2.
  - Input: A normal image.
  - Target/Label: The image's original categorical label.
  - Classification accuracy should be low (no transfer learning & too little labeled data!)
  - The model may overfit.
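To make the "labels for free" idea concrete, here is a minimal sketch (mine, not from the original notebook) of how a rotation label can be generated from an image alone:

import random
from PIL import Image

degrees = [0, 45, 90, 135, 180, 225, 270, 315]

def make_rotation_example(img: Image.Image):
    "Return (rotated image, rotation class); the label comes from the data itself."
    label = random.randint(0, len(degrees) - 1)  # pick one of the 8 rotation classes
    return img.rotate(degrees[label]), label     # PIL rotates counter-clockwise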
# Uncomment and run the below line to get a fresh install of fastai, if needed
# !pip install fastai --upgrade
FastAI Vision Model Creation Methods
Below we define `simple_arch`. `simple_arch` takes in one argument, `pretrained`. This allows FastAI to pass `pretrained=True` or `pretrained=False` when creating the model body! Below are some use cases of when we would want `pretrained=True` or `pretrained=False`.
- `pretrained=False`: For training a new model on our rotation prediction task.
- `pretrained=True`: For transferring the learnt features from our rotation task pretraining to solve a classification task.
- `pretrained=False`: For training a new model from scratch on the main classification task (no transfer learning).
from fastai.vision.all import *
def simple_arch(pretrained=False):
    # Note that FastAI will automatically cut at the pooling layer for the body!
    model = nn.Sequential(
        nn.Conv2d(1, 4, 3, 1),
        nn.BatchNorm2d(4),
        nn.ReLU(),
        nn.Conv2d(4, 16, 3, 1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, 1),
        nn.BatchNorm2d(32),
        nn.AdaptiveAvgPool2d(1),
    )
    if pretrained:
        # save_path should point at the directory where the rotation-pretrained
        # weights were exported (see the export sketch at the end of this section).
        print("Loading pretrained model...")
        pretrained_weights = torch.load(save_path/'rot_pretrained.pt')
        print(model.load_state_dict(pretrained_weights))
    return model
The code snippets below are examples of how FastAI creates CNNs. Every model has a body and a head.
body = create_body(arch=simple_arch, pretrained=False)
body
head = create_head(nf=32, n_out=8, lin_ftrs=[])
head
# Note that FastAI automatically determines nf for the head!
model = create_cnn_model(arch=simple_arch, pretrained=False, n_out=8, lin_ftrs=[])
model
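As a quick sanity check (a sketch, assuming the 1-channel 32x32 inputs used later in this notebook), we can pass a dummy batch through the assembled model:

x = torch.randn(2, 1, 32, 32)  # dummy batch: 2 one-channel 32x32 images
print(model(x).shape)          # expected: torch.Size([2, 8]), one logit per class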
PyTorch Rotation/Classification Self-Supervised Dataset
import torchvision
tensorToImage = torchvision.transforms.ToPILImage()
imageToTensor = torchvision.transforms.ToTensor()
# Uncomment and run the below lines if torchvision has trouble downloading MNIST (in the next cell)
# !wget -P data/MNIST/raw/ http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
# !wget -P data/MNIST/raw/ http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
# !wget -P data/MNIST/raw/ http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
# !wget -P data/MNIST/raw/ http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
torchvision.datasets.MNIST('data/', download=True)
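Note (an assumption on my part, not from the original notebook): recent torchvision releases no longer write the legacy processed/training.pt and processed/test.pt files that the dataset below expects. If they are missing after the download above, here is a minimal sketch to recreate them:

import os
train_raw = torchvision.datasets.MNIST('data/', train=True, download=True)
test_raw = torchvision.datasets.MNIST('data/', train=False, download=True)
os.makedirs('data/MNIST/processed', exist_ok=True)
# Save (images, labels) tuples in the same layout the old processed files used
torch.save((train_raw.data, train_raw.targets), 'data/MNIST/processed/training.pt')
torch.save((test_raw.data, test_raw.targets), 'data/MNIST/processed/test.pt')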
Below we define a dataset; here's the docstring:
A Dataset for Rotation-based Self-Supervision! Images are rotated counter-clockwise (PIL's rotation direction).
- `file` - MNIST processed .pt file.
- `pct` - percent of the data to use.
- `classification` - False = use rotation labels. True = use original classification labels.
class Custom_Dataset_MNIST():
    '''
    A Dataset for Rotation-based Self-Supervision! Images are rotated counter-clockwise.
    - file - MNIST processed .pt file.
    - pct - percent of the data to use.
    - classification - False = use rotation labels. True = use original classification labels.
    '''
    def __init__(self, file, pct, classification):
        data = torch.load(file)
        self.imgs = data[0]
        self.labels = data[1]
        self.pct = pct
        self.classification = classification
        slice_idx = int(len(self.imgs)*self.pct)
        self.imgs = self.imgs[:slice_idx]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = self.imgs[idx].unsqueeze(0)
        img = tensorToImage(img)
        img = img.resize((32, 32), resample=1)  # resample=1 is PIL's LANCZOS filter
        img = imageToTensor(img)
        if not self.classification:
            # 8 rotation classes, one per 45-degree increment
            degrees = [0, 45, 90, 135, 180, 225, 270, 315]
            rand_choice = random.randint(0, len(degrees)-1)
            img = tensorToImage(img)
            img = img.rotate(degrees[rand_choice])
            img = imageToTensor(img)
            return img, torch.tensor(rand_choice).long()
        return img, self.labels[idx]

    def show_batch(self, n=3):
        fig, axs = plt.subplots(n, n)
        fig.tight_layout()
        for i in range(n):
            for j in range(n):
                rand_idx = random.randint(0, len(self)-1)
                img, label = self.__getitem__(rand_idx)
                axs[i, j].imshow(tensorToImage(img), cmap='gray')
                if self.classification:
                    axs[i, j].set_title('Label: {0} (Digit #{1})'.format(label.item(), label.item()))
                else:
                    axs[i, j].set_title('Label: {0} ({1} Degrees)'.format(label.item(), label.item()*45))
                axs[i, j].axis('off')
Rotation Prediction Data
train_ds = Custom_Dataset_MNIST('data/MNIST/processed/training.pt', pct=1.0, classification=False)
valid_ds = Custom_Dataset_MNIST('data/MNIST/processed/test.pt', pct=1.0, classification=False)
print('{0} Training Samples | {1} Validation Samples'.format(len(train_ds), len(valid_ds)))
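As a quick sanity check (a sketch; the shape assumes the 32x32 resize in the dataset above), we can grab one sample and confirm the tensor layout:

img, label = train_ds[0]
print(img.shape, label)  # expected: torch.Size([1, 32, 32]) and a rotation class in 0-7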
from fastai.data.core import DataLoaders
dls = DataLoaders.from_dsets(train_ds, valid_ds).cuda()
# Override the show_batch function of dls to the one used in our dataset!
dls.show_batch = train_ds.show_batch
# We have 8 classes! [0, 1, 2, 3, 4, 5, 6, 7] that correspond to the [0, 45, 90, 135, 180, 225, 270, 315] degrees of rotation.
dls.c = 8
dls.show_batch()
FastAI Vision Learner [Rotation]
rotation_head = create_head(nf=32, n_out=8, lin_ftrs=[])
rotation_head
We'll use `top_2_accuracy` along with regular (top-1) accuracy, because there are hard cases where it's understandable why our model got it wrong. For example: a '0' rotated 90 or 270 degrees, or a '1' rotated 0 or 180 degrees. (They can look the same!)
# - A zero rotated 90 or 270 degrees?
# - A one rotated 0 or 180 degrees?
# etc :P
top_2_accuracy = lambda inp, targ: top_k_accuracy(inp, targ, k=2)
top_2_accuracy
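To make the metric concrete, here is a quick illustration with made-up logits (fastai's top_k_accuracy counts a prediction as correct when the target is among the k highest-scoring classes):

# Made-up logits for 2 samples over 3 classes (illustration only)
inp = torch.tensor([[0.1, 0.5, 0.4],
                    [0.8, 0.05, 0.15]])
targ = torch.tensor([2, 1])
print(top_2_accuracy(inp, targ))  # 0.5: target 2 is in row 0's top-2; target 1 is not in row 1's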
Here, we train a model on the rotation prediction task!
# Note: instead of setting lin_ftrs, we pass the rotation_head defined above as a custom head.
learner = cnn_learner(dls,
                      simple_arch,
                      pretrained=False,
                      loss_func=CrossEntropyLossFlat(),
                      custom_head=rotation_head,
                      metrics=[accuracy, top_2_accuracy])
learner.model
learner.summary()
learner.lr_find()
learner.fit_one_cycle(5, lr_max=3e-2)
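Once rotation pretraining is done, the body's weights need to be exported so that `simple_arch(pretrained=True)` can reload them for the transfer experiment. A minimal sketch (the save_path location and export step here are my assumptions; the notebook defines its own save location):

save_path = Path('data')  # assumed save location; any existing directory works
# learner.model[0] is the body; its state_dict keys line up with simple_arch's layers
torch.save(learner.model[0].state_dict(), save_path/'rot_pretrained.pt')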