Skip to content

Instantly share code, notes, and snippets.

@bobchennan
Created March 15, 2019 18:49
Show Gist options
  • Save bobchennan/a865b153c6835a3a6a5c628213766150 to your computer and use it in GitHub Desktop.
Save bobchennan/a865b153c6835a3a6a5c628213766150 to your computer and use it in GitHub Desktop.
GELS with derivative in PyTorch
import torch
from torch.autograd import Function
class GELS(Function):
    """Differentiable batched least-squares solve via the normal equations.

    Solves ``min_x ||A x - b||`` by Cholesky-factorizing the Gramian
    ``A^T A`` and solving ``(A^T A) x = A^T b``. The factorization fails
    (raises) when ``A^T A`` is not positive definite, i.e. when ``A`` does
    not have full column rank.

    NOTE(review): only plain transposes are used (no conjugation), so this
    presumably targets real-valued tensors -- confirm before using with
    complex dtypes.

    Original implementation by Nanxin Chen (bobchennan@gmail.com).
    """

    @staticmethod
    def forward(ctx, A, b):
        """Return the least-squares solution ``x`` of ``A x = b``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            A: coefficient tensor of shape ``(..., M, N)``.
            b: right-hand-side tensor of shape ``(..., M, K)``.

        Returns:
            Tensor of shape ``(..., N, K)``.
        """
        # Same scheme as TensorFlow's matrix_solve_ls (fast path):
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_ops.py#L267
        a_t = A.transpose(-1, -2)
        # torch.potrs / torch.cholesky are removed / deprecated; use the
        # torch.linalg factorization (lower-triangular factor) together with
        # torch.cholesky_solve, whose default is upper=False.
        chol = torch.linalg.cholesky(torch.matmul(a_t, A))
        ret = torch.cholesky_solve(torch.matmul(a_t, b), chol)
        ctx.save_for_backward(chol, ret, A, b)
        return ret

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_A, grad_b)`` for the incoming gradient w.r.t. ``x``.

        An entry is ``None`` when autograd reports that the corresponding
        input does not need a gradient.
        """
        # Gradient formula follows TensorFlow's overdetermined case:
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L223
        chol, x, a, b = ctx.saved_tensors
        # z = (A^T A)^{-1} * grad_output, reusing the forward factorization.
        z = torch.cholesky_solve(grad_output, chol)
        z_t = z.transpose(-1, -2)

        grad_A = None
        grad_b = None
        if ctx.needs_input_grad[0]:
            xzt = torch.matmul(x, z_t)
            zx_sym = xzt + xzt.transpose(-1, -2)
            grad_A = -torch.matmul(a, zx_sym) + torch.matmul(b, z_t)
        if ctx.needs_input_grad[1]:
            grad_b = torch.matmul(a, z)
        return grad_A, grad_b
if __name__ == "__main__":
    # Small smoke test: compare GELS against reference least-squares solvers
    # and numerically verify the hand-written backward pass.
    import numpy as np

    A = torch.nn.Parameter(torch.randn(1, 10, 4, dtype=torch.double))
    b = torch.nn.Parameter(torch.randn(1, 10, 1, dtype=torch.double))

    # forward test
    print(GELS.apply(A, b))
    # torch.gels was removed from PyTorch; torch.linalg.lstsq is its
    # replacement (note the argument order is (A, b), not (b, A)).
    print(torch.linalg.lstsq(A.detach().squeeze(0), b.detach().squeeze(0)).solution)
    # rcond=None selects NumPy's current default and avoids the FutureWarning
    # emitted by older NumPy versions.
    print(np.linalg.lstsq(A.detach().numpy()[0], b.detach().numpy()[0], rcond=None))

    # backward test
    print(torch.autograd.gradcheck(GELS.apply, [A, b]))
@LemonPi
Copy link

LemonPi commented May 11, 2019

New pytorch version warns of deprecation for potrs; you can replace line 15 with

        ret = torch.cholesky_solve(torch.matmul(A.transpose(-1, -2), b), u, upper=True)

and line 23 with

        z = torch.cholesky_solve(grad_output, chol, upper=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment