@t-vi
Created October 13, 2017 07:27
Computing the Variance of Gradients for Linear Layers
import torch
from torch.autograd import Variable


def linear_with_sumsq(inp, weight, bias=None):
    # Linear layer that, on the backward pass, also accumulates the batch mean of
    # the squared per-sample gradients in weight.grad_sumsq / bias.grad_sumsq
    # (assuming the loss is a mean over the batch).
    def provide_sumsq(inp, w, b):
        def _h(i):
            # i is d loss / d output; the per-sample weight gradient is the outer
            # product of i and inp, so its elementwise square factorizes.
            if not hasattr(w, 'grad_sumsq'):
                w.grad_sumsq = 0
            w.grad_sumsq += ((i**2).t().matmul(inp**2)) * i.size(0)
            if b is not None:
                if not hasattr(b, 'grad_sumsq'):
                    b.grad_sumsq = 0
                b.grad_sumsq += (i**2).sum(0) * i.size(0)
        return _h

    res = inp.matmul(weight.t())
    if bias is not None:
        res = res + bias
    res.register_hook(provide_sumsq(inp, weight, bias))
    return res


weight = Variable(torch.randn(3, 2), requires_grad=True)
inp = Variable(torch.randn(4, 2))
bias = Variable(torch.randn(3), requires_grad=True)

c = linear_with_sumsq(inp, weight, bias)
d = (c**2).sum(1).mean(0)
d.backward()

# manual variance calculation: one backward pass per sample to get per-sample gradients
gr = []
gr_b = []
for i in range(len(inp)):
    w_i = Variable(weight.data, requires_grad=True)
    b_i = Variable(bias.data, requires_grad=True)
    i_i = inp[i:i+1]
    c_i = i_i.matmul(w_i.t()) + b_i
    d_i = (c_i**2).sum()
    d_i.backward()
    gr.append(w_i.grad.data)
    gr_b.append(b_i.grad.data)
gr = torch.stack(gr, dim=0)
gr_b = torch.stack(gr_b, dim=0)

# both differences should match the per-sample gradient variances
print(gr.var(0, unbiased=False), weight.grad_sumsq - weight.grad**2,
      gr_b.var(0, unbiased=False), bias.grad_sumsq - bias.grad**2)
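For context (notation mine, not part of the gist): for a linear layer $f_{nj} = \sum_k x_{nk} w_{jk} + b_j$ and a loss $d = \frac{1}{N}\sum_n \ell_n$, the per-sample weight gradient is a rank-one outer product,
$$g^{(n)}_{jk} = \frac{\partial \ell_n}{\partial f_{nj}} x_{nk}, \qquad \left(g^{(n)}_{jk}\right)^2 = \left(\frac{\partial \ell_n}{\partial f_{nj}}\right)^2 x_{nk}^2,$$
so the squared per-sample gradients can be accumulated from the squared batch tensors alone (the bias case is the same with $x_{nk}$ replaced by 1). The hook receives $\frac{1}{N}\frac{\partial \ell_n}{\partial f_{nj}}$; squaring introduces a factor $1/N^2$, and the multiplication by i.size(0) turns the sum over the batch into the mean $\frac{1}{N}\sum_n \left(g^{(n)}\right)^2$. The printed differences are therefore
$$\frac{1}{N}\sum_n \left(g^{(n)}_{jk}\right)^2 - \left(\frac{1}{N}\sum_n g^{(n)}_{jk}\right)^2 = \mathrm{Var}_n\!\left[g^{(n)}_{jk}\right],$$
the (biased) variance of the per-sample gradients, which is what gr.var(0, unbiased=False) computes directly.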
t-vi commented Oct 13, 2017

The convolution is defined by
$$f_{cij} = \sum_{dkl} w_{cdkl} \, \mathrm{inp}_{d,i+k,j+l},$$
so the derivative is
$$\frac{\partial f_{cij}}{\partial w_{cdkl}} = \mathrm{inp}_{d,i+k,j+l},$$
and the total derivative of some loss $\mathrm{out}$ is
$$\frac{\partial\, \mathrm{out}}{\partial w_{cdkl}} = \sum_{ij} \frac{\partial f_{cij}}{\partial w_{cdkl}} \cdot \frac{\partial\, \mathrm{out}}{\partial f_{cij}} = \sum_{ij} \mathrm{inp}_{d,i+k,j+l} \cdot \frac{\partial\, \mathrm{out}}{\partial f_{cij}}.$$
Computing the sum over $ij$ first and then squaring it before summing over the batch does not seem to be possible right now.
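
A quick numerical check of this formula (not from the gist; the toy shapes, the 3x3 kernel, and the unfold-based reconstruction are my own choices, and it assumes a PyTorch version that provides torch.nn.functional.unfold):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(1, 3, 8, 8)                          # one sample, d = 3 input channels
weight = torch.randn(4, 3, 3, 3, requires_grad=True)   # c = 4 output channels, 3x3 kernel

f = F.conv2d(inp, weight)          # f_{cij}, shape (1, 4, 6, 6)
out = (f**2).sum()                 # some scalar loss "out"
out.backward()
grad_f = (2 * f).detach()          # d out / d f_{cij}, known in closed form here

# materialize inp_{d,i+k,j+l}: one column of length d*k*l per output position (i, j)
cols = F.unfold(inp, kernel_size=3)                             # shape (1, 27, 36)
manual = grad_f.reshape(1, 4, -1).matmul(cols.transpose(1, 2))  # sum over i, j
manual = manual.reshape(4, 3, 3, 3)

print(torch.allclose(weight.grad, manual, atol=1e-5))           # True

Per sample this reproduces autograd's weight gradient; the difficulty above is that the squaring would have to happen after this sum over $ij$ but before the reduction over the batch.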
