@ncullen93
Last active July 9, 2020 08:14
import torch as th
from torch.autograd import Variable
from torch.autograd.function import Function


class Symeig(Function):
    """Differentiable symmetric eigendecomposition (legacy-style autograd Function)."""

    def __init__(self, eigenvectors=True, upper=True):
        self.eigenvectors = eigenvectors
        self.upper = upper

    def forward(self, input):
        w, v = th.symeig(input, eigenvectors=self.eigenvectors, upper=self.upper)
        self.save_for_backward(input, w, v)
        return w, v

    def backward(self, grad_w, grad_v):
        x, w, v = self.saved_tensors
        N = x.size(0)
        # Only the triangle that symeig actually read receives gradient;
        # the other (strict) triangle is filled in by transposition below.
        if self.upper:
            tri0 = th.triu
            tri1 = lambda a: th.tril(a, -1)
        else:
            tri0 = th.tril
            tri1 = lambda a: th.triu(a, 1)

        def G(n):
            # Eigenvector term for column n: sum over m != n of
            # v_m * (dL/dv_n . v_m) / (w_n - w_m); assumes distinct eigenvalues.
            return sum([v[:, m] * grad_v.t()[n].matmul(v[:, m]) / (w[n] - w[m])
                        for m in range(N) if m != n])

        # Accumulate dL/dA as outer products: sum_n v_n (grad_w_n * v_n + G(n))^T.
        g = sum([th.ger(v[:, n], v[:, n] * grad_w[n] + G(n))
                 for n in range(N)])
        out = tri0(g) + tri1(g).t()
        return out


def symeig(input, eigenvectors=True, upper=True):
    return Symeig(eigenvectors, upper)(input)
def test_runtime():
    """test that there are no runtime errors"""
    import torch as th
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F

    x = Variable(th.randn(30, 10))
    w = nn.Parameter(th.rand(30, 10), requires_grad=True)
    xw = F.linear(x, w)  # (30, 30) matrix
    a, b = symeig(xw)
    asum = a.sum()
    asum.backward()
def test_gradcheck():
    """test gradcheck"""
    from torch.autograd import gradcheck
    input = (Variable(th.randn(50, 50).double(), requires_grad=True),)
    test = gradcheck(Symeig(), input, eps=1e-6, atol=1e-4)
    print(test)
def test_theano():
    """test that we get same results as theano"""
    ## TEST DATA
    import numpy as np
    XX = np.random.randn(30, 10)
    YY = np.random.randn(10, 30)

    ## THEANO
    import theano
    import theano.tensor as T
    x = theano.shared(XX)
    w = theano.shared(YY)
    xw = T.dot(x, w)
    xw = xw + xw.T - T.diag(xw.diagonal())  # symmetrize
    [e, v] = T.nlinalg.eigh(xw)
    e_sum = e.sum()
    w_grad = theano.grad(e_sum, w)
    x_grad = theano.grad(e_sum, x)
    wgrad_theano = w_grad.eval()
    xgrad_theano = x_grad.eval()
    e_theano = e.eval()
    v_theano = v.eval()

    ## PYTORCH
    import torch as th
    import torch.nn as nn
    x = nn.Parameter(th.from_numpy(XX))
    w = nn.Parameter(th.from_numpy(YY))
    xw = th.mm(x, w)
    xw = xw + xw.t() - th.diag(th.diag(xw))  # symmetrize the same way
    [e, v] = symeig(xw)
    e_sum = e.sum()
    e_sum.backward()
    wgrad_torch = w.grad.data.numpy()
    xgrad_torch = x.grad.data.numpy()
    e_torch = e.data.numpy()
    v_torch = v.data.numpy()

    # Check that the two gradients are the same
    print(np.allclose(wgrad_theano, wgrad_torch))
    print(np.allclose(xgrad_theano, xgrad_torch))
    print(np.allclose(e_theano, e_torch))
    print(np.allclose(np.abs(v_theano), np.abs(v_torch)))
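For quick reference, here is a minimal usage sketch. It assumes the same legacy torch.symeig / Variable API the gist targets; the tensor names below are illustrative only.

# Minimal usage sketch (legacy PyTorch API as above; names are illustrative).
import torch as th
from torch.autograd import Variable

A = Variable(th.randn(5, 5).double(), requires_grad=True)
S = A + A.t()      # symmetrize so the eigendecomposition is well defined
w, v = symeig(S)   # differentiable eigenvalues / eigenvectors
loss = w.sum()     # any scalar built from the eigenvalues
loss.backward()    # gradients reach A through Symeig.backward
print(A.grad)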
@jaideep11061982
My memory gets full when I use zoom from your code.
Why should that be?
