BYOL tutorial: self-supervised learning on CIFAR images with code in PyTorch



After presenting SimCLR, a contrastive self-supervised learning framework, I decided to demonstrate another famous method, called BYOL. Bootstrap Your Own Latent (BYOL) is a new algorithm for self-supervised learning of image representations. BYOL has two main advantages:

  • It does not explicitly use negative samples. Instead, it directly minimizes the distance between representations of the same image under different augmented views (a positive pair). Negative samples would be the images in the batch other than the positive pair.

  • As a result, BYOL is claimed to work with smaller batch sizes, which makes it an attractive choice.

Below, you can examine the method. Unlike the original paper, I call the online network the student and the target network the teacher.


Overview of the BYOL method. Source: BYOL paper

Online network aka student: compared to SimCLR, there is a second MLP, called the predictor, which makes the whole method asymmetric. Asymmetric compared to what? Well, to the teacher model (target network).

Why is that important?

Because the teacher model is updated only through an exponential moving average (EMA) of the student's parameters. Ultimately, at each iteration, a tiny percentage (less than 1%) of the student's parameters is passed on to the teacher. Thus, gradients flow only through the student network. This can be implemented as:

class EMA():
    """Exponential moving average of the student weights into the teacher."""
    def __init__(self, alpha):
        self.alpha = alpha

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.alpha + (1 - self.alpha) * new

ema = EMA(0.99)

for student_params, teacher_params in zip(student_model.parameters(), teacher_model.parameters()):
    old_weight, up_weight = teacher_params.data, student_params.data
    teacher_params.data = ema.update_average(old_weight, up_weight)

Another key difference between SimCLR and BYOL is the loss function.

Loss function

The predictor MLP is only applied to the student, making the architecture asymmetric. This is a key design choice to avoid mode collapse. Mode collapse here would be to output the same projection for all inputs.


Overview of the BYOL method. Source: BYOL paper

Finally, the authors defined the following mean squared error between the L2-normalized predictions and target projections:

$$\mathcal{L}_{\theta, \xi} \triangleq \left\|\bar{q}_{\theta}(z_{\theta}) - \bar{z}'_{\xi}\right\|_{2}^{2} = 2 - 2 \cdot \frac{\left\langle q_{\theta}(z_{\theta}),\, z'_{\xi}\right\rangle}{\left\|q_{\theta}(z_{\theta})\right\|_{2} \cdot \left\|z'_{\xi}\right\|_{2}}$$

The L2 loss can be implemented as follows. L2 normalization is applied beforehand.

import torch
import torch.nn.functional as F

def loss_fn(x, y):
    # normalize to unit length, so the MSE reduces to 2 - 2 * cosine similarity
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
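As a quick sanity check (not part of the original code), the loss is 0 for identical vectors and reaches its maximum of 4 for opposite ones:

x = torch.tensor([[1.0, 0.0]])
print(loss_fn(x, x))   # tensor([0.]) -> identical views, zero loss
print(loss_fn(x, -x))  # tensor([4.]) -> opposite directions, maximum loss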

Code is available on GitHub.

Keeping track of what's happening in self-supervised pretraining: KNN accuracy

However, the loss in self-supervised learning is not a reliable metric to track. What I found to be the best way to track what's happening while training is to measure the KNN accuracy.

The key advantage of using KNN is that we don't have to train a linear classifier on top each time, so it's faster and completely unsupervised.

Note: Measuring KNN accuracy only applies to image classification, but you get the idea. For this purpose, I made a class to encapsulate the logic of KNN in our context:

import numpy as np
import torch
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from torch import nn

class KNN():
    def __init__(self, model, k, device):
        super(KNN, self).__init__()
        self.k = k
        self.device = device
        self.model = model.to(device)
        self.model.eval()

    def extract_features(self, loader):
        """
        Infer/Extract features from a trained model
        Args:
            loader: train or test loader
        Returns: 3 tensors of all: input_images, features, labels
        """
        x_lst = []
        features = []
        label_lst = []

        with torch.no_grad():
            for input_tensor, label in loader:
                h = self.model(input_tensor.to(self.device))
                features.append(h)
                x_lst.append(input_tensor)
                label_lst.append(label)

            x_total = torch.stack(x_lst)
            h_total = torch.stack(features)
            label_total = torch.stack(label_lst)

            return x_total, h_total, label_total

    def knn(self, features, labels, k=1):
        """
        Evaluating knn accuracy in feature space.
        Calculates only top-1 accuracy (returns 0 for top-5)
        Args:
            features: [... , dataset_size, feat_dim]
            labels: [... , dataset_size]
            k: nearest neighbours
        Returns: train accuracy, or train and test acc
        """
        feature_dim = features.shape[-1]
        with torch.no_grad():
            features_np = features.cpu().view(-1, feature_dim).numpy()
            labels_np = labels.cpu().view(-1).numpy()
            # fit a cosine-distance KNN classifier on the extracted features
            self.cls = KNeighborsClassifier(n_neighbors=k, metric="cosine").fit(features_np, labels_np)
            acc = self.eval(features, labels)
        return acc

    def eval(self, features, labels):
        feature_dim = features.shape[-1]
        features = features.cpu().view(-1, feature_dim).numpy()
        labels = labels.cpu().view(-1).numpy()
        acc = 100 * np.mean(cross_val_score(self.cls, features, labels))
        return acc

    def _find_best_indices(self, h_query, h_ref):
        h_query = h_query / h_query.norm(dim=1).view(-1, 1)
        h_ref = h_ref / h_ref.norm(dim=1).view(-1, 1)
        scores = torch.matmul(h_query, h_ref.t())  # [query_bs, ref_bs]
        score, indices = scores.topk(1, dim=1)  # top-1 best match
        return score, indices

    def fit(self, train_loader, test_loader=None):
        with torch.no_grad():
            x_train, h_train, l_train = self.extract_features(train_loader)
            train_acc = self.knn(h_train, l_train, k=self.k)

            if test_loader is not None:
                x_test, h_test, l_test = self.extract_features(test_loader)
                test_acc = self.eval(h_test, l_test)
                return train_acc, test_acc
            return train_acc

Now we can focus on the method and the BYOL model.

Modify resnet: add MLP projection heads

We will start with a base model (resnet18) and modify it for self-supervised learning. The last layer, which normally does the classification, is replaced with an identity function. The output features of resnet18 will be fed to the MLP projector.

import copy
import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, dim, embedding_size=256, hidden_size=2048, batch_norm_mlp=False):
        super().__init__()
        norm = nn.BatchNorm1d(hidden_size) if batch_norm_mlp else nn.Identity()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            norm,
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, embedding_size)
        )

    def forward(self, x):
        return self.net(x)

class AddProjHead(nn.Module):
    def __init__(self, model, in_features, layer_name, hidden_size=4096,
                 embedding_size=256, batch_norm_mlp=True):
        super(AddProjHead, self).__init__()
        self.backbone = model
        # replace the classification head with an identity function
        setattr(self.backbone, layer_name, nn.Identity())
        # adapt resnet18 to 32x32 images: 3x3 first conv, no max pooling
        self.backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = torch.nn.Identity()
        # add the MLP projection head
        self.projection = MLP(in_features, embedding_size, hidden_size=hidden_size, batch_norm_mlp=batch_norm_mlp)

    def forward(self, x, return_embedding=False):
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)

I also replaced the first conv layer of resnet18 from a 7×7 to a 3×3 convolution, since we are playing with 32×32 images (CIFAR-10).
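As a small usage sketch (not taken from the original post), the wrapper can be applied to a torchvision resnet18, whose classification head is named fc and receives 512 features:

from torchvision.models import resnet18

backbone = resnet18()  # random weights, no pretraining
student_model = AddProjHead(model=backbone, in_features=512, layer_name='fc',
                            hidden_size=4096, embedding_size=256, batch_norm_mlp=True)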

Code is available on GitHub. If you are planning to solidify your PyTorch knowledge, there are two amazing books that we highly recommend: Deep Learning with PyTorch from Manning Publications and Machine Learning with PyTorch and Scikit-Learn by Sebastian Raschka. You can always use the 35% discount code blaisummer21 for all Manning's products.

The actual BYOL method

So far I have presented all the important components needed to reach this point. Now we will build the BYOL module with our beloved student and teacher networks. Notice that the student's predictor MLP and projector are identical.

My implementation of BYOL is based on lucidrains' repo. I modified it to make it simpler and easier to play around with.

class BYOL(nn.Module):
    def __init__(self,
                 net,
                 batch_norm_mlp=True,
                 layer_name='fc',
                 in_features=512,
                 projection_size=256,
                 projection_hidden_size=2048,
                 moving_average_decay=0.99,
                 use_momentum=True):
        """
        Args:
            net: model to be trained
            batch_norm_mlp: whether to use batchnorm1d in the mlp predictor and projector
            layer_name: name of the backbone's classification head that gets replaced
            in_features: the number of features produced by the backbone net, i.e. resnet
            projection_size: the size of the output vector of the two identical MLPs
            projection_hidden_size: the size of the hidden vector of the two identical MLPs
            moving_average_decay: tau hyperparameter that controls the influence in the target network weight update
            use_momentum: whether to update the target network
        """
        super().__init__()
        self.net = net
        self.student_model = AddProjHead(model=net, in_features=in_features,
                                         layer_name=layer_name,
                                         embedding_size=projection_size,
                                         hidden_size=projection_hidden_size,
                                         batch_norm_mlp=batch_norm_mlp)
        self.use_momentum = use_momentum
        self.teacher_model = self._get_teacher()
        self.target_ema_updater = EMA(moving_average_decay)
        self.student_predictor = MLP(projection_size, projection_size, projection_hidden_size)

    def _get_teacher(self):
        # the teacher starts as an exact copy of the student and is never trained by gradients
        return copy.deepcopy(self.student_model)

    def update_moving_average(self):
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.teacher_model is not None, 'target encoder has not been created yet'

        for student_params, teacher_params in zip(self.student_model.parameters(), self.teacher_model.parameters()):
            old_weight, up_weight = teacher_params.data, student_params.data
            teacher_params.data = self.target_ema_updater.update_average(old_weight, up_weight)

    def forward(self,
                image_one, image_two=None,
                return_embedding=False):
        if return_embedding or (image_two is None):
            return self.student_model(image_one, return_embedding=True)

        # student projections: backbone + MLP projector
        student_proj_one = self.student_model(image_one)
        student_proj_two = self.student_model(image_two)

        # additional student prediction MLP on top of the projections
        student_pred_one = self.student_predictor(student_proj_one)
        student_pred_two = self.student_predictor(student_proj_two)

        with torch.no_grad():
            # teacher projections: no gradients flow through the teacher
            teacher_proj_one = self.teacher_model(image_one).detach_()
            teacher_proj_two = self.teacher_model(image_two).detach_()

        # symmetrized BYOL loss: each student prediction is regressed onto the
        # teacher projection of the *other* view, as in the equation above
        loss_one = loss_fn(student_pred_one, teacher_proj_two)
        loss_two = loss_fn(student_pred_two, teacher_proj_one)

        return (loss_one + loss_two).mean()

For CIFAR-10 it's enough to use 2048 as the hidden dimension and 256 as the embedding dimension. We will train a resnet18 that outputs 512 features for 100 epochs. The parts of the code that refer to data loading and augmentations are omitted to improve readability; you can look them up in the repo. A rough sketch of the two-view augmentation pipeline is shown below.
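The following is only an illustrative SimCLR-style augmentation stack (the exact transforms and parameters in the repo may differ); each image is augmented twice to create the positive pair that the dataloader returns:

import torchvision.transforms as T

# hypothetical augmentation parameters for 32x32 CIFAR images
augment = T.Compose([
    T.RandomResizedCrop(32, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomGrayscale(p=0.2),
    T.ToTensor(),
])

class TwoViews:
    """Wraps a transform so the dataset returns two independently augmented views."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        return self.transform(img), self.transform(img)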

You can use the Adam optimizer ($lr = 3 \cdot 10^{-4}$) for all experiments.

The only thing that changes in the training code is the EMA update of the teacher.

def training_step(model, data):
    (view1, view2), _ = data
    loss = model(view1.cuda(), view2.cuda())
    return loss

def train_one_epoch(model, train_dataloader, optimizer):
    model.train()
    total_loss = 0.
    num_batches = len(train_dataloader)
    for data in train_dataloader:
        optimizer.zero_grad()
        loss = training_step(model, data)
        loss.backward()
        optimizer.step()
        # EMA update: pass a small fraction of the student weights to the teacher
        model.update_moving_average()
        total_loss += loss.item()
    return total_loss / num_batches
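A minimal outer loop could look like the sketch below. It is not part of the original code: train_dataloader is assumed to yield the two augmented views, while train_eval_loader and test_loader are hypothetical plain single-view CIFAR-10 loaders used only for the KNN evaluation.

from torchvision.models import resnet18

model = BYOL(resnet18(), in_features=512, batch_norm_mlp=True).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(100):
    epoch_loss = train_one_epoch(model, train_dataloader, optimizer)
    if epoch % 4 == 0:
        # calling the BYOL model with a single image returns the backbone embedding,
        # so it can be passed directly to the KNN helper defined above
        knn_eval = KNN(model=model, k=1, device='cuda')
        train_acc, test_acc = knn_eval.fit(train_eval_loader, test_loader)
        print(f'epoch {epoch}: loss {epoch_loss:.4f}, KNN val acc {test_acc:.2f}%')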

Let's jump to the results!

Results: KNN accuracy vs. pretraining epochs


KNN accuracy every 4 epochs. Image by author

Isn't it amazing that without any labels we can reach a validation accuracy of 70%? I found this impressive, especially for a method that seems to be less sensitive to the batch size.

But why does the batch size have an effect here? Isn't BYOL supposed to not use negative pairs? Where does the dependence on the batch size come from?

Short answer: it's the batch normalization in the MLP layers!

Here are the experiments I ran to cross-check it.

A note on batch norm in MLP networks and EMA momentum

I was curious to observe mode collapse without batch normalization. You can try that yourself by setting:

model = BYOL(model, in_features=512, batch_norm_mlp=False)

I observed that the L2 distance goes to almost zero from the very first epochs:

Epoch 0: loss:0.06423207696957084

Epoch 8: loss:0.005584242034894534

Epoch 20: loss:0.005460431350347323

The loss goes to roughly zero and the KNN accuracy stops increasing (35% vs. 60% in the normal setup). That's why it is claimed that BYOL implicitly uses a form of contrastive learning by leveraging the batch statistics in the MLPs. Here is the KNN accuracy:


Mode collapse in BYOL when removing batch norm in the MLPs. Image by author

I am well aware of papers showing that batch statistics are not the only condition for BYOL to work. This is an experimental post, so I'm not going to play that game. I was just curious to observe mode collapse here.


For a more detailed explanation of the method, check Yannic's video on BYOL:

In this tutorial, we implemented BYOL step by step and pretrained it on CIFAR10. We observed the massive increase in KNN accuracy by matching the representations of the same image. A random classifier would reach 10%, and with 100 epochs we reached 70% KNN validation accuracy without any labels. How cool is that?

To learn more about self-supervised learning, stay tuned! Support us by sharing on social media, making a donation, or buying our Deep Learning in Production book. It would be highly appreciated.

Deep Learning in Production Book

Learn how to build, train, deploy, scale and maintain deep learning models. Understand ML infrastructure and MLOps using hands-on examples.


* Disclosure: Please note that some of the links above might be affiliate links, and at no additional cost to you, we will earn a commission if you decide to make a purchase after clicking through.