@andres-fr
Last active November 18, 2022 16:05
Different ways of computing GSNR with PyTorch+BackPACK
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
!!! WARNING !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
The quantities computed in this gist are wrong. While a correction is
pending, see the following discussion for more details. The TLDR is that
they must be divided by batch_size^2, or computed with the alternative
Cockpit formula (see also the note next to the GSNR computations below):
https://github.com/f-dangel/cockpit/issues/28
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
This MNIST example uses backpack to obtain the GSNR during training in 3
different ways:
1. Easy but memory-inefficient: we gather all element-wise (per-sample)
   gradients and then compute GSNR manually. This is inefficient because we
   need to store model_size*batch_size values.
2. Easy and memory-efficient: we use the ``Variance`` backpack extension to
   gather batch-wise variances for each parameter, so we never need to store
   all element-wise gradients and memory usage stays at model_size plus a
   small constant.
3. Possibly faster version of 2. using Var(X) = E[X²] - E[X]². Note that the
   batch_size scaling must be taken into account.
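Here, GSNR is computed per parameter entry as the squared (averaged) gradient
divided by the variance of the per-sample gradients across the batch, i.e.
GSNR = grad**2 / Var(grad), which is what all three variants below implement.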
--------------------------------------------------------------------------------
The following commands install all needed dependencies from scratch::
conda create -n gsnr python==3.9
conda activate gsnr
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install backpack-for-pytorch==1.5.0
--------------------------------------------------------------------------------
To run it::
python gsnr_backpack.py
"""
import torch
import torch.nn.functional as F
#
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
#
from backpack import backpack, extend
from backpack.extensions import BatchGrad, SumGradSquared, Variance
# check out the different extensions available: https://docs.backpack.pt
# ##############################################################################
# # HELPERS
# ##############################################################################
class SimpleCNN(torch.nn.Module):
    """
    Small CNN for MNIST classification: two convolutions followed by
    max-pooling and two fully connected layers, returning 10 logits.
    """

    def __init__(self):
        """
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        # after two 3x3 convs (28->26->24) and 2x2 max-pooling (24->12),
        # the flattened activations have 64 * 12 * 12 = 9216 entries
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """
        :param x: Batch of MNIST images of shape ``(b, 1, 28, 28)``.
        :returns: Logits of shape ``(b, 10)``.
        """
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.leaky_relu(self.fc1(x))
        x = self.fc2(x)
        # we skip softmax, feed logits directly to loss_fn
        # x = F.log_softmax(x, dim=1)
        return x
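# Quick shape sanity check (a sketch, not part of the training script below):
# feeding a dummy MNIST-sized batch through an (unextended) SimpleCNN yields
# one logit row per sample, e.g.
#   SimpleCNN()(torch.zeros(4, 1, 28, 28)).shape == torch.Size([4, 10])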
# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
if __name__ == "__main__":
    # globals
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MNIST_ROOTPATH = "/tmp"
    NUM_EPOCHS = 10
    BATCH_SIZE = 32
    LEARNING_RATE = 0.01
    EPSILON = 1e-5
    SEED = 54321
    # seed
    torch.manual_seed(SEED)
    # dataset, dataloader
    mnist = MNIST(MNIST_ROOTPATH, train=True, transform=ToTensor(),
                  download=True)
    dataloader = torch.utils.data.DataLoader(mnist, BATCH_SIZE, shuffle=True)
    # model, loss, optimizer
    model = SimpleCNN().to(DEVICE)
    loss_fn = torch.nn.CrossEntropyLoss()  # receives (logits, idx_target)
    opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
    # backpack wrapper
    model = extend(model)
    loss_fn = extend(loss_fn)
    # training loop
    for epoch_i in range(NUM_EPOCHS):
        for batch_i, (batch_x, batch_y) in enumerate(dataloader):
            with torch.no_grad():
                batch_x = batch_x.to(DEVICE)  # shape: (b, 1, 28, 28)
                batch_y = batch_y.to(DEVICE)  # shape: (b,)
            #
            opt.zero_grad()
            preds = model(batch_x)  # shape: (b, 10)
            loss = loss_fn(preds, batch_y)
            # This is the important bit: call backward with BP extensions
            with backpack(
                    # Elementwise grads, needed only for the "inefficient" way
                    BatchGrad(),
                    # g**2 / var(g), memory efficient+easy
                    Variance(),
                    # Possibly faster version using Var(X) = E[X²] - E[X]²
                    SumGradSquared()):
                loss.backward()
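            # After this backward pass, BackPACK attaches extra fields to each
            # parameter p (per its docs): p.grad_batch holds the per-sample
            # gradients with shape (b, *p.shape), while p.variance and
            # p.sum_grad_squared have the same shape as p.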
            # now we have the quantities available. Compute layerwise GSNR:
            params = list(model.parameters())
            grads_2 = [p.grad ** 2 for p in params]
            # 1. Memory-inefficient way:
            variances = [p.grad_batch.var(dim=0) for p in params]
            gsnr1 = [g2 / (v + EPSILON)
                     for g2, v in zip(grads_2, variances)]
            # 2. Using built-in extensions (eltwise grads never needed)
            gsnr2 = [g2 / (p.variance + EPSILON)
                     for g2, p in zip(grads_2, params)]
            # 3. Possibly faster version using Var(X) = E[X²] - E[X]²
            gsnr3 = [g2 / ((p.sum_grad_squared / BATCH_SIZE) -
                           (g2 / BATCH_SIZE**2) + EPSILON)
                     for g2, p in zip(grads_2, params)]
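            # NOTE: as the warning in the module docstring explains, these
            # GSNR values still need a correction, e.g. dividing by
            # BATCH_SIZE**2 (or using the alternative Cockpit formula from
            # https://github.com/f-dangel/cockpit/issues/28). A sketch of the
            # former, with an illustrative name and left unverified:
            # gsnr1_corrected = [g / BATCH_SIZE ** 2 for g in gsnr1]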
            # test that the alternatives do provide the same GSNR
            print([torch.allclose(g1, g2, rtol=0.01)
                   for g1, g2 in zip(gsnr1, gsnr2)])
            print([torch.allclose(g1, g3, rtol=0.01)
                   for g1, g3 in zip(gsnr1, gsnr3)])
            # finalize the batch training step and report
            opt.step()
            print("loss:", loss.item())