Self-supervised learning tutorial: Implementing SimCLR with PyTorch Lightning



In this hands-on tutorial, we will provide you with a reimplementation of the SimCLR self-supervised learning method for pretraining robust feature extractors. The method is fairly general and can be applied to any vision dataset, as well as to different downstream tasks.

In a previous tutorial, I wrote a bit of background on the self-supervised learning arena. Time to get into your first project by running SimCLR on a small dataset with 100K unlabelled images called STL10.

Code is available on Github.

The SimCLR method: contrastive learning

Let $\operatorname{sim}(u, v)$ denote the dot product between two L2-normalized vectors $u$ and $v$ (i.e. their cosine similarity).

Then the loss function for a positive pair of examples $(i, j)$ is defined as:

$$\ell_{i,j} = -\log \frac{\exp\left(\operatorname{sim}(\boldsymbol{z}_i, \boldsymbol{z}_j) / \tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\operatorname{sim}(\boldsymbol{z}_i, \boldsymbol{z}_k) / \tau\right)}$$

where $\mathbb{1}_{[k \neq i]} \in \{0, 1\}$ is an indicator function that evaluates to 1 iff $k \neq i$.

$\tau$ denotes a temperature parameter. The final loss is computed by summing over all positive pairs and dividing by $2 \times N = views \times batch\_size$.
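
To make the formula concrete, here is a deliberately naive sketch that evaluates it term by term with Python loops. The function names `pair_loss` and `nt_xent_naive` are made up for this illustration and are not part of the tutorial's code; the vectorized implementation further down is what you would use in practice.

```python
import math
import torch
import torch.nn.functional as F

def pair_loss(z, i, j, tau):
    """l(i, j) from the formula above, for rows i and j of the L2-normalized matrix z."""
    sim = z @ z.T  # cosine similarities, because every row has unit norm
    numerator = math.exp(sim[i, j].item() / tau)
    # the indicator 1[k != i] simply skips the self-similarity term
    denominator = sum(
        math.exp(sim[i, k].item() / tau) for k in range(z.shape[0]) if k != i
    )
    return -math.log(numerator / denominator)

def nt_xent_naive(z1, z2, tau=0.5):
    """Average of l(i, j) and l(j, i) over all N positive pairs (2N terms in total)."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), p=2, dim=1)
    total = 0.0
    for i in range(n):
        total += pair_loss(z, i, i + n, tau) + pair_loss(z, i + n, i, tau)
    return total / (2 * n)

torch.manual_seed(0)
z1, z2 = torch.randn(4, 8), torch.randn(4, 8)
loss_random = nt_xent_naive(z1, z2)
print(loss_random)
```

Because the positive term also appears in the denominator, each term is strictly positive as soon as the batch contains at least one negative.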

There are different ways to develop the contrastive loss. Here we provide you with some important info.

L2 normalization and cosine similarity matrix calculation

First, one needs to apply an L2 normalization to the features, otherwise this method does not work. L2 normalization means that the vectors are normalized such that they all lie on the surface of the unit (hyper)sphere, where the L2 norm is 1.

z_i = F.normalize(proj_1, p=2, dim=1)

z_j = F.normalize(proj_2, p=2, dim=1)
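
As a quick sanity check of what the normalization buys us, the sketch below uses random tensors as stand-ins for `proj_1` and `proj_2` (the shapes are arbitrary assumptions) and verifies that every row ends up with unit length, and that the plain dot product of normalized rows equals the cosine similarity of the raw rows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
proj_1 = torch.randn(4, 128)  # stand-in for a batch of projections
proj_2 = torch.randn(4, 128)

z_i = F.normalize(proj_1, p=2, dim=1)
z_j = F.normalize(proj_2, p=2, dim=1)

norms = z_i.norm(p=2, dim=1)                      # every row now has length 1
dot = (z_i * z_j).sum(dim=1)                      # row-wise dot products
cos = F.cosine_similarity(proj_1, proj_2, dim=1)  # cosine similarity of the raw vectors

print(norms)
print(torch.allclose(dot, cos, atol=1e-5))
```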

Concatenate the 2 output views in the batch dimension. Their shape will be $[2 \times batch\_size, dim]$.

def calc_similarity_batch(self, a, b):
    representations = torch.cat([a, b], dim=0)
    return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
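
The broadcasting trick in this function can be checked in isolation. The following sketch (toy sizes, random inputs) builds the same all-pairs matrix and compares it with a plain matrix product, which gives the same result when the rows are already L2-normalized.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, dim = 3, 16
z_i = F.normalize(torch.randn(batch_size, dim), dim=1)
z_j = F.normalize(torch.randn(batch_size, dim), dim=1)

representations = torch.cat([z_i, z_j], dim=0)  # [2 * batch_size, dim]

# Broadcasting [2N, 1, dim] against [1, 2N, dim] compares every row with every row.
similarity_matrix = F.cosine_similarity(
    representations.unsqueeze(1), representations.unsqueeze(0), dim=2
)

# For unit-norm rows the same matrix is a plain matrix product.
similarity_matmul = representations @ representations.T

print(similarity_matrix.shape)
print(torch.allclose(similarity_matrix, similarity_matmul, atol=1e-5))
```

The matrix-product form avoids materializing the intermediate `[2N, 2N, dim]` tensor, which may matter for large batches.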

Indexing the similarity matrix for the SimCLR loss function

Now we need to index the resulting matrix of size $[batch\_size \times views, batch\_size \times views]$.

A visual illustration of SimCLR. Image from the author

Okay, how on earth do we do that? I had the same question. Here the batch size is 2 images, but we want to implement a solution for any batch size. If you look closely, you will see that the positive pairs are shifted from the main diagonal by 2, that is, by the batch size. One way to do that is torch.diag(). It takes the selected diagonal from a matrix. The first parameter is the matrix and the second specifies the diagonal, where zero represents the main diagonal elements. We take the two diagonals that are shifted by the batch size.

sim_ij = torch.diag(similarity_matrix, batch_size)

sim_ji = torch.diag(similarity_matrix, -batch_size)

positives = torch.cat([sim_ij, sim_ji], dim=0)
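
To see which entries the two shifted diagonals pick, here is a small sketch on a toy matrix whose entry at row r, column c is `10 * r + c`, so each value reveals its own position. The toy matrix is only an illustration and stands in for a real similarity matrix.

```python
import torch

batch_size = 3
rows = torch.arange(2 * batch_size).unsqueeze(1)
cols = torch.arange(2 * batch_size).unsqueeze(0)
similarity_matrix = 10 * rows + cols  # entry (r, c) holds 10 * r + c

sim_ij = torch.diag(similarity_matrix, batch_size)   # entries (0,3), (1,4), (2,5)
sim_ji = torch.diag(similarity_matrix, -batch_size)  # entries (3,0), (4,1), (5,2)
positives = torch.cat([sim_ij, sim_ji], dim=0)

print(positives.tolist())  # [3, 14, 25, 30, 41, 52]
```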

There are $batch\_size \times views$ positive pairs. As another example, for a [6, 6] similarity matrix (batch_size=3, views=2), the positive pairs sit at the positions marked with 1:

[0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.]

For the denominator we need both the positive and negative pairs. So the binary mask will be the exact element-wise inverse of the identity matrix.

self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()
pos_and_negatives = self.mask * similarity_matrix

Again, the denominator contains both the positives and the negatives. The mask only removes the self-similarity terms on the main diagonal.
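
The sketch below shows the mask at work on a small random example (sizes and temperature are arbitrary assumptions). Note that it applies the mask after exponentiating, as the full implementation further down does, so that the diagonal terms contribute exactly zero to the row sums.

```python
import torch
import torch.nn.functional as F

batch_size, temperature = 2, 0.5
torch.manual_seed(0)
z = F.normalize(torch.randn(2 * batch_size, 8), dim=1)
similarity_matrix = z @ z.T

# 1 everywhere except on the main diagonal
mask = (~torch.eye(2 * batch_size, dtype=torch.bool)).float()

exp_sim = torch.exp(similarity_matrix / temperature)
denominator = (mask * exp_sim).sum(dim=1)  # one value per row i, summing over k != i

# Same thing computed as "full row sum minus the diagonal term"
reference = exp_sim.sum(dim=1) - exp_sim.diagonal()

print(mask)
print(torch.allclose(denominator, reference, atol=1e-5))
```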

You can work out the rest (temperature scaling, summing the terms of the denominator, etc.):

SimCLR loss implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

def device_as(t1, t2):
    """
    Moves t1 to the device of t2
    """
    return t1.to(t2.device)

class ContrastiveLoss(nn.Module):
    """
    Vanilla Contrastive loss, also called InfoNceLoss as in SimCLR paper
    """
    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # the mask is built for a fixed batch size, so every batch must contain exactly batch_size samples
        self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()

    def calc_similarity_batch(self, a, b):
        representations = torch.cat([a, b], dim=0)
        return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

    def forward(self, proj_1, proj_2):
        """
        proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
        where corresponding indices are pairs
        z_i, z_j in the SimCLR paper
        """
        batch_size = proj_1.shape[0]
        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)

        similarity_matrix = self.calc_similarity_batch(z_i, z_j)

        # positive pairs live on the two diagonals shifted by the batch size
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        nominator = torch.exp(positives / self.temperature)

        # all pairs except the self-similarity terms
        denominator = device_as(self.mask, similarity_matrix) * torch.exp(similarity_matrix / self.temperature)

        all_losses = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(all_losses) / (2 * self.batch_size)
        return loss
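
As a cross-check, the same quantity can be written as a cross-entropy over the rows of the similarity matrix, where the "class" of row i is the column of its positive. This is an alternative formulation, not the code used in this tutorial, and the function name `nt_xent_cross_entropy` is made up for the sketch.

```python
import torch
import torch.nn.functional as F

def nt_xent_cross_entropy(proj_1, proj_2, temperature=0.5):
    n = proj_1.shape[0]
    z = F.normalize(torch.cat([proj_1, proj_2], dim=0), p=2, dim=1)
    logits = (z @ z.T) / temperature
    # Remove self-similarity from the softmax by setting the diagonal to -inf.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i's positive is at column i + n (first half) or i - n (second half).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, targets)  # mean over the 2N rows

torch.manual_seed(0)
p1, p2 = torch.randn(4, 16), torch.randn(4, 16)
loss_value = nt_xent_cross_entropy(p1, p2)
print(loss_value.item())
```

One possible advantage of this form is numerical stability: the log-softmax inside `F.cross_entropy` avoids exponentiating large values directly, which can matter at low temperatures.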


The key to self-supervised representation learning is data augmentations. A commonly used transformation pipeline is the following:

  • Crop on a random scale from 7% to 100% of the image

  • Resize all images to 224 or other spatial dimensions.

  • Apply horizontal flipping with 50% probability

  • Apply heavy color jittering with 80% probability

  • Apply gaussian blur with 50% probability. Kernel size is usually around 10% of the image or less.

  • Convert RGB images to grayscale with 20% probability.

  • Normalize based on the means and standard deviations of imagenet

This pipeline will be applied independently to each image twice and it will produce two different views that will be fed into the backbone model. In this notebook, we will use a standard resnet18.

import torch
import torchvision.transforms as T

class Augment:
    """
    A stochastic data augmentation module
    Transforms any given data example randomly
    resulting in two correlated views of the same example,
    denoted x̃_i and x̃_j, which we consider as a positive pair.
    """

    def __init__(self, img_size, s=1):
        color_jitter = T.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        blur = T.GaussianBlur((3, 3), (0.1, 2.0))

        # crop, flip and grayscale follow the list above (assumed transforms and arguments)
        self.train_transform = torch.nn.Sequential(
            T.RandomResizedCrop(size=img_size),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomApply([blur], p=0.5),
            T.RandomGrayscale(p=0.2),
            # imagenet statistics
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    def __call__(self, x):
        # two independent passes through the same stochastic pipeline give two views
        return self.train_transform(x), self.train_transform(x)

Below are 4 different views of the same image, obtained by applying the same stochastic pipeline:

4 different augmentations of the same image with the same pipeline. Image by author

To visualize them you need to undo the mean-std normalization and put the color channels in the last dimension:

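import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.datasets import STL10

def imshow(img):
    """
    shows an imagenet-normalized image on the screen
    """
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    # Normalize(-mean/std, 1/std) is the inverse of Normalize(mean, std)
    unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    npimg = unnormalize(img).numpy()
    # channels first [C, H, W] -> channels last [H, W, C]
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataset = STL10("./", split='train', transform=Augment(96), download=True)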
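
Why does `T.Normalize(-mean / std, 1 / std)` undo the normalization? Normalizing computes `y = (x - mean) / std`. Applying the same operation with the new parameters gives `(y + mean / std) * std = y * std + mean = x`. The sketch below checks this with plain tensor arithmetic on a random fake image; the helper `normalize` is defined here only to mirror the formula.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(img, m, s):
    """Channel-wise (img - m) / s, the arithmetic behind a mean-std normalization."""
    return (img - m) / s

torch.manual_seed(0)
img = torch.rand(3, 96, 96)  # fake RGB image in [0, 1], channels first
normalized = normalize(img, mean, std)
restored = normalize(normalized, -mean / std, 1.0 / std)
hwc = restored.permute(1, 2, 0)  # channels last, the layout plt.imshow expects

print(torch.allclose(restored, img, atol=1e-5), tuple(hwc.shape))
```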





Modify Resnet18 and define parameter groups

One important step to run SimCLR is to remove the last fully connected layer. We will replace it with an identity function. Then, we need to add the projection head (another MLP) that will be used only for the self-supervised pretraining stage. To do so, we need to be aware of the dimension of the features of our model. In particular, resnet18 outputs a 512-dim vector while resnet50 outputs a 2048-dim vector. The projection MLP will transform it to the embedding vector size, which is 128 based on the official paper.

To optimize SSL models we use heavy regularization techniques, like weight decay. To avoid performance deterioration we need to exclude the weight decay from the batch normalization layers.

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torch.optim import SGD, Adam

def default(val, def_val):
    # small helper: fall back to def_val when val is None
    return def_val if val is None else val

class AddProjection(nn.Module):
    def __init__(self, config, model=None, mlp_dim=512):
        super(AddProjection, self).__init__()
        embedding_size = config.embedding_size
        self.backbone = default(model, models.resnet18(pretrained=False, num_classes=config.embedding_size))
        mlp_dim = default(mlp_dim, self.backbone.fc.in_features)
        print('Dim MLP input:', mlp_dim)
        # drop the classification layer so the backbone returns raw features
        self.backbone.fc = nn.Identity()

        # projection head, used only during self-supervised pretraining
        self.projection = nn.Sequential(
            nn.Linear(in_features=mlp_dim, out_features=mlp_dim),
            nn.ReLU(),  # assumed non-linearity between the two linear layers, as in the SimCLR paper
            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
        )

    def forward(self, x, return_embedding=False):
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)
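
The "replace `fc` with identity, then attach a head" pattern can be tried without downloading anything. In the sketch below, `TinyBackbone` is a hypothetical stand-in for resnet18 that only shares the relevant property: a feature extractor followed by a classifier stored in an attribute called `fc`.

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for resnet18: features followed by an `fc` classifier."""
    def __init__(self, feat_dim=512, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, feat_dim, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

backbone = TinyBackbone()
feat_dim = backbone.fc.in_features  # read the feature size before dropping the classifier
backbone.fc = nn.Identity()         # the backbone now returns raw features

projection = nn.Sequential(
    nn.Linear(feat_dim, feat_dim),
    nn.ReLU(),
    nn.Linear(feat_dim, 128),       # 128-dim embedding, as in the text
)

x = torch.randn(2, 3, 32, 32)
features = backbone(x)
embedding = projection(features)
print(tuple(features.shape), tuple(embedding.shape))
```

Reading `fc.in_features` before the swap is the step that makes the head independent of the backbone choice (512 for resnet18, 2048 for resnet50).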

The next step is to separate the model's parameters into 2 groups.

The purpose of the second group is to remove weight decay from batch normalization layers. In the case of using the LARS optimizer, you also need to remove weight decay from biases. One way to achieve that is the following function:

def define_param_groups(model, weight_decay, optimizer_name):
    def exclude_from_wd_and_adaptation(name):
        # batch norm parameters are matched by their name
        if 'bn' in name:
            return True
        if optimizer_name == 'lars' and 'bias' in name:
            return True
        return False

    param_groups = [
        {
            'params': [p for name, p in model.named_parameters() if not exclude_from_wd_and_adaptation(name)],
            'weight_decay': weight_decay,
            'layer_adaptation': True,
        },
        {
            'params': [p for name, p in model.named_parameters() if exclude_from_wd_and_adaptation(name)],
            'weight_decay': 0.,
            'layer_adaptation': False,
        },
    ]
    return param_groups
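
To see the effect of such a split, the sketch below applies the same name-based idea to a hypothetical three-layer model and hands the groups to a standard PyTorch optimizer. `SmallNet` and `split_by_name` are illustrative names, not part of the tutorial's code.

```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    """Hypothetical model whose batch-norm layer is named `bn1`."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.bn1 = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 4)

def split_by_name(model, weight_decay):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        (no_decay if "bn" in name else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

net = SmallNet()
groups = split_by_name(net, weight_decay=1e-4)
optimizer = torch.optim.SGD(groups, lr=0.1, momentum=0.9)

for g in optimizer.param_groups:
    print(len(g["params"]), g["weight_decay"])
```

One caveat of matching on names: a batch-norm layer placed inside an `nn.Sequential` gets a numeric name such as `projection.1.weight`, so a check for the substring `bn` would not catch it. Matching on the module type is a possible alternative if that matters for your model.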

I am not using the LARS optimizer in this tutorial, but if you plan to use it, here is an implementation that I use as a reference.

SimCLR training logic

Here we will implement the whole training logic of SimCLR. Take 2 views, forward them to get the embedding projections, and calculate the SimCLR loss.

We can wrap up the SimCLR training with one class using PyTorch Lightning that encapsulates all the training logic. In its simplest form, we need to implement the training_step method that gets as input a batch from the dataloader. You can think of it as calling batch = next(iter(dataloader)) in each step. Next comes the configure_optimizers method, which binds the model with the optimizer and the training scheduler. I used an already implemented scheduler from PyTorch Lightning Bolts (another small package in the lightning ecosystem). Essentially, we gradually increase the learning rate to its base value and then we do cosine annealing.

class SimCLR_pl(pl.LightningModule):
    def __init__(self, config, model=None, feat_dim=512):
        super().__init__()
        self.config = config
        self.augment = Augment(config.img_size)
        self.model = AddProjection(config, model=model, mlp_dim=feat_dim)
        self.loss = ContrastiveLoss(config.batch_size, temperature=self.config.temperature)

    def forward(self, X):
        return self.model(X)

    def training_step(self, batch, batch_idx):
        x, labels = batch  # labels are not used during pretraining
        x1, x2 = self.augment(x)
        z1 = self.model(x1)
        z2 = self.model(x2)
        loss = self.loss(z1, z2)
        self.log('Contrastive loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        max_epochs = int(self.config.epochs)
        param_groups = define_param_groups(self.model, self.config.weight_decay, 'adam')
        lr = self.config.lr  # assumed attribute name for the base learning rate
        optimizer = Adam(param_groups, lr=lr, weight_decay=self.config.weight_decay)

        print(f'Optimizer Adam, '
              f'Learning Rate {lr}, '
              f'Effective batch size {self.config.batch_size * self.config.gradient_accumulation_steps}')

        # linear warmup for 10 epochs, then cosine annealing until max_epochs
        scheduler_warmup = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=max_epochs)

        return [optimizer], [scheduler_warmup]
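
To get a feel for the shape of this schedule, here is a minimal stand-alone sketch of "linear warmup, then cosine annealing" as a plain function of the epoch. It is an approximation written for illustration (the name `warmup_cosine_lr` and the base learning rate are assumptions), and the exact per-epoch values produced by the library scheduler may differ slightly.

```python
import math

def warmup_cosine_lr(epoch, base_lr, warmup_epochs, max_epochs, start_lr=0.0):
    """Learning rate at a given epoch: linear ramp to base_lr, then half a cosine down to 0."""
    if epoch < warmup_epochs:
        return start_lr + (base_lr - start_lr) * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

schedule = [
    warmup_cosine_lr(e, base_lr=1e-3, warmup_epochs=10, max_epochs=100)
    for e in range(101)
]

for e in (0, 5, 10, 55, 100):
    print(e, schedule[e])
```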

Gradient accumulation and effective batch size

Here it is important to highlight the importance of using a big batch size. The method depends heavily on a large batch size, because every other image in the batch serves as a negative that is pushed away from the 2 views of the same image (the positives). To do that on a restricted budget we can use gradient accumulation. We average the gradients of $N$ steps and then update the model, instead of updating after each forward-backward pass.

Thus, it should now make full sense that the effective batch size is: $batch\_size\_per\_gpu \times accumulation\_steps \times number\_of\_gpus$

PyTorch Lightning offers gradient accumulation as a callback.

"In computer programming, a callback is a reference to executable code or a piece of executable code that is passed as an argument to other code. This allows a lower-level software layer to call a subroutine (or function) defined in a higher-level layer." ~ StackOverflow

from pytorch_lightning.callbacks import GradientAccumulationScheduler

accumulator = GradientAccumulationScheduler(scheduling={0: train_config.gradient_accumulation_steps})
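
What accumulation does under the hood can be sketched in a few lines of plain PyTorch. The example uses a toy linear model and a mean-squared-error loss (both assumptions for the illustration) and checks that accumulating over 4 micro-batches of 2 samples gives the same gradient as one pass over all 8 samples.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
data, target = torch.randn(8, 10), torch.randn(8, 1)
loss_fn = nn.MSELoss()
accumulation_steps = 4

# Reference: one backward pass over the full batch of 8.
model.zero_grad()
loss_fn(model(data), target).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2; .grad keeps adding up between backward calls.
model.zero_grad()
for x, y in zip(data.chunk(accumulation_steps), target.chunk(accumulation_steps)):
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so the sum is an average
    loss.backward()
accumulated_grad = model.weight.grad.clone()
# an optimizer.step() would go here, once per accumulation_steps micro-batches

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```

The equivalence is exact here because the loss is a simple average over samples. For a contrastive loss it is only approximate: the negatives of each sample still come from its own micro-batch, so accumulation smooths the gradient but does not add negatives.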

Main SimCLR pretraining script

The main script just collects everything together and initializes the Trainer class of PyTorch Lightning. You can then run it on a single or multiple GPUs. Note that in the snippet below, I am reading all the available GPUs of the system.

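import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.models import resnet18

# Hparams and get_stl_dataloader are not shown in this tutorial
available_gpus = torch.cuda.device_count()
save_model_path = os.path.join(os.getcwd(), "saved_models/")
print('available_gpus:', available_gpus)

filename = 'SimCLR_ResNet18_adam_'  # placeholder checkpoint name
resume_from_checkpoint = False
train_config = Hparams()

save_name = filename + '.ckpt'

model = SimCLR_pl(train_config, model=resnet18(pretrained=False), feat_dim=512)

data_loader = get_stl_dataloader(train_config.batch_size)

accumulator = GradientAccumulationScheduler(scheduling={0: train_config.gradient_accumulation_steps})
checkpoint_callback = ModelCheckpoint(filename=filename, dirpath=save_model_path, every_n_val_epochs=2,
                                      save_last=True, save_top_k=2, monitor='Contrastive loss_epoch', mode='min')

# Trainer arguments other than callbacks are assumptions (PyTorch Lightning 1.x names);
# train_config.checkpoint_path is a hypothetical field holding the checkpoint to resume from
if resume_from_checkpoint:
    trainer = Trainer(callbacks=[accumulator, checkpoint_callback],
                      gpus=available_gpus,
                      max_epochs=train_config.epochs,
                      resume_from_checkpoint=train_config.checkpoint_path)
else:
    trainer = Trainer(callbacks=[accumulator, checkpoint_callback],
                      gpus=available_gpus,
                      max_epochs=train_config.epochs)

trainer.fit(model, data_loader)

# assumed: save the final weights and download them when running on Google Colab
trainer.save_checkpoint(save_name)
from google.colab import files
files.download(save_name)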

Okay, we trained a model. Now it is time for fine-tuning. We will use the PyTorch Lightning module class to encapsulate the logic. I am taking the pretrained resnet18 backbone, without the projection head, and I am only adding one linear layer on top. I am fine-tuning the whole network. No augmentations are applied here. They would only delay the training. Instead, we want to quantify the performance against pretrained weights on imagenet and random initialization.

import pytorch_lightning as pl
import torch
from torch.optim import SGD

class SimCLR_eval(pl.LightningModule):
    def __init__(self, lr, model=None, linear_eval=False):
        super().__init__()
        self.lr = lr
        self.linear_eval = linear_eval
        if self.linear_eval:
            # assumed: keep the pretrained backbone in eval mode for linear evaluation
            model.eval()

        # one linear layer on top of the 512-dim resnet18 features (STL10 has 10 classes)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(512, 10),
        )

        self.model = torch.nn.Sequential(
            model, self.mlp
        )
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, X):
        return self.model(X)

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = self.loss(z, y)
        self.log('Cross Entropy loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        predicted = z.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        self.log('Train Acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = self.loss(z, y)
        self.log('Val CE loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)

        predicted = z.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        self.log('Val Accuracy', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        if self.linear_eval:
            # only the linear layer is optimized, the backbone stays fixed
            print("\n\n Attention! Linear evaluation \n")
            optimizer = SGD(self.mlp.parameters(), lr=self.lr, momentum=0.9)
        else:
            optimizer = SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
        return [optimizer]
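
The difference between the two modes comes down to which parameters the optimizer receives. The sketch below makes that visible by counting them for a small hypothetical backbone (a single linear layer standing in for resnet18) with a 10-class linear head; the helper names are made up for the illustration.

```python
import torch
import torch.nn as nn
from torch.optim import SGD

# Hypothetical stand-in for the pretrained backbone: 3x8x8 input -> 512 features
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512), nn.ReLU())
head = nn.Linear(512, 10)  # STL10 has 10 classes
model = nn.Sequential(backbone, head)

def build_optimizer(linear_eval, lr=1e-3):
    # linear evaluation: only the head; fine-tuning: backbone and head
    params = head.parameters() if linear_eval else model.parameters()
    return SGD(params, lr=lr, momentum=0.9)

def n_optimized(optimizer):
    return sum(p.numel() for g in optimizer.param_groups for p in g["params"])

linear_count = n_optimized(build_optimizer(linear_eval=True))
finetune_count = n_optimized(build_optimizer(linear_eval=False))
print(linear_count, finetune_count)  # 5130 103946
```

Linear evaluation therefore measures how linearly separable the frozen features are, while fine-tuning lets the backbone adapt to the labels as well.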

Importantly, STL10 is a subset of imagenet, so transfer learning from imagenet is expected to work very well.

| Method | Fine-tuning the whole network, validation accuracy | Linear evaluation, validation accuracy |
| --- | --- | --- |
| SimCLR pretraining on STL10 unlabelled split | 75.1% | 73.2% |
| Imagenet pretraining (1M) | 87.9% | 78.6% |
| Random initialization | 50.6% | - |

In all cases the model overfits during fine-tuning. Remember that no augmentations were applied.


Even with an unfair evaluation compared to pretrained weights from imagenet, contrastive self-supervised learning demonstrates some super promising results. There are many other self-supervised methods to play with, but SimCLR is the baseline.

To wrap up, we explored how to build the SimCLR loss function step by step and how to launch a training script without too much boilerplate code with PyTorch Lightning. Even though there is still a gap between SimCLR representations and those learned with imagenet supervision, the latest state-of-the-art methods are catching up and even surpass imagenet-learned features in many domains.

Thanks for your interest in AI and stay positive!
