Different ways of computing GSNR with PyTorch+BackPACK
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
!!! WARNING !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
The quantities computed in this gist are wrong. While a correction is
pending, see the following discussion for more details. The TLDR is that
they must be divided by batch_size^2, or use the alternative Cockpit
formula:
https://github.com/f-dangel/cockpit/issues/28
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
This MNIST example uses backpack to obtain the GSNR during training in 3
different ways:

1. Easy but memory-inefficient: we gather all element-wise gradients and then
   compute GSNR manually. This is inefficient because we need to store
   model_size*batch_size values.
2. Easy and memory-efficient: we use the ``Variance`` backpack extension to
   gather batch-wise variances for each parameter, so we never need to store
   all element-wise gradients and only end up storing
   model_size+small_constant values.
3. Possibly faster version of 2., using Var(X) = E[X²] - E[X]². Note that the
   batch_size scaling must be taken into account (see the illustrative helper
   sketches in the HELPERS section below).
--------------------------------------------------------------------------------
The following commands install all needed dependencies from scratch::

    conda create -n gsnr python==3.9
    conda activate gsnr
    conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
    pip install backpack-for-pytorch==1.5.0

--------------------------------------------------------------------------------
To run the script::

    python gsnr_backpack.py
"""
import torch
import torch.nn.functional as F
#
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
#
from backpack import backpack, extend
from backpack.extensions import BatchGrad, SumGradSquared, Variance
# check out the different extensions available: https://docs.backpack.pt


# ##############################################################################
# # HELPERS
# ##############################################################################
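# NOTE: illustrative sketch added for reference; it is not called by the
# training loop below. It assumes one common definition of the element-wise
# GSNR: mean(g)**2 / var(g), with mean and variance taken over the per-sample
# gradients of a batch (mind the batch_size scaling issue flagged in the
# module docstring warning).
def gsnr_reference(grad_batch, eps=1e-5):
    """Reference GSNR from explicit per-sample gradients.

    :param grad_batch: Tensor of shape ``(batch_size, *param_shape)`` holding
      one gradient per sample (e.g. the ``grad_batch`` attribute produced by
      the ``BatchGrad`` extension).
    :param eps: Small constant avoiding division by zero.
    :returns: Tensor of shape ``param_shape`` with the element-wise GSNR.
    """
    mean_sq = grad_batch.mean(dim=0) ** 2
    var = grad_batch.var(dim=0)
    return mean_sq / (var + eps)
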
class SimpleCNN(torch.nn.Module):
    """
    Small convolutional network for MNIST classification: two conv layers
    followed by max-pooling and two fully connected layers producing logits.
    """

    def __init__(self):
        """
        Instantiate the layers.
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """
        :param x: Batch of MNIST images with shape ``(b, 1, 28, 28)``.
        :returns: Logits of shape ``(b, 10)``.
        """
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.leaky_relu(self.fc1(x))
        x = self.fc2(x)
        # we skip softmax, feed logits directly to loss_fn
        # x = F.log_softmax(x, dim=1)
        return x
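

# NOTE: another illustrative sketch, not used by the training loop: it checks
# the identity Var(X) = E[X**2] - E[X]**2 that motivates way 3 in the module
# docstring. ``Tensor.var`` defaults to the *unbiased* estimator (divides by
# batch_size - 1), whereas the identity yields the population variance, hence
# ``unbiased=False`` below.
def variance_identity_holds(grad_batch, rtol=1e-4, atol=1e-7):
    """Return True if Var(X) = E[X**2] - E[X]**2 holds for ``grad_batch``.

    :param grad_batch: Tensor of shape ``(batch_size, *param_shape)``.
    """
    e_x2 = (grad_batch ** 2).mean(dim=0)
    sq_e_x = grad_batch.mean(dim=0) ** 2
    population_var = grad_batch.var(dim=0, unbiased=False)
    return torch.allclose(population_var, e_x2 - sq_e_x, rtol=rtol, atol=atol)
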
# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
if __name__ == "__main__":
    # globals
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    MNIST_ROOTPATH = "/tmp"
    NUM_EPOCHS = 10
    BATCH_SIZE = 32
    LEARNING_RATE = 0.01
    EPSILON = 1e-5
    SEED = 54321
    # seed
    torch.manual_seed(SEED)
    # dataset, dataloader
    mnist = MNIST(MNIST_ROOTPATH, train=True, transform=ToTensor(),
                  download=True)
    dataloader = torch.utils.data.DataLoader(mnist, BATCH_SIZE, shuffle=True)
    # model, loss, optimizer
    model = SimpleCNN().to(DEVICE)
    loss_fn = torch.nn.CrossEntropyLoss()  # receives (logits, idx_target)
    opt = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
    # wrap model and loss with BackPACK so its extensions can be used
    model = extend(model)
    loss_fn = extend(loss_fn)
    # training loop
    for epoch_i in range(NUM_EPOCHS):
        for batch_i, (batch_x, batch_y) in enumerate(dataloader):
            with torch.no_grad():
                batch_x = batch_x.to(DEVICE)  # shape: (b, 1, 28, 28)
                batch_y = batch_y.to(DEVICE)  # shape: (b,)
            #
            opt.zero_grad()
            preds = model(batch_x)  # shape: (b, 10)
            loss = loss_fn(preds, batch_y)
            # This is the important bit: call backward with BackPACK extensions
            with backpack(
                    # Element-wise grads, needed only for the "inefficient" way
                    BatchGrad(),
                    # g**2 / var(g) via batch-wise variances, memory-efficient
                    Variance(),
                    # Possibly faster version using Var(X) = E[X²] - E[X]²
                    SumGradSquared()):
                loss.backward()
            # now we have the quantities available. Compute layer-wise GSNR:
            params = list(model.parameters())
            grads_2 = [p.grad ** 2 for p in params]
            # 1. Memory-inefficient way: per-sample grads gathered explicitly
            variances = [p.grad_batch.var(dim=0) for p in params]
            gsnr1 = [g2 / (v + EPSILON)
                     for g2, v in zip(grads_2, variances)]
            # 2. Using built-in extensions (element-wise grads never needed)
            gsnr2 = [g2 / (p.variance + EPSILON)
                     for g2, p in zip(grads_2, params)]
            # 3. Possibly faster version using Var(X) = E[X²] - E[X]²
            gsnr3 = [g2 / ((p.sum_grad_squared / BATCH_SIZE) -
                           (g2 / BATCH_SIZE**2) + EPSILON)
                     for g2, p in zip(grads_2, params)]
            # test that the alternatives agree with the reference computation
            print([torch.allclose(g1, g2, rtol=0.01)
                   for g1, g2 in zip(gsnr1, gsnr2)])
            print([torch.allclose(g1, g3, rtol=0.01)
                   for g1, g3 in zip(gsnr1, gsnr3)])
            # finalize the batch training step and report
            opt.step()
            print("loss:", loss.item())