mfkasim1 / eig.py
Created January 23, 2020 14:11
Differentiable torch.eig() for real eigenvalues
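Making the eigendecomposition differentiable means supplying a backward rule. The sketch below is not taken from eig.py, whose backward pass falls outside this excerpt. It is a NumPy stand-in for the standard eigenvalue part of such a rule, assuming A is diagonalizable with real, distinct eigenvalues. Writing A = V diag(lambda) V^{-1}, the rows of V^{-1} are the left eigenvectors, so d lambda_i = (V^{-1} dA V)_ii, and a cotangent g on the eigenvalues maps back to A as V^{-T} diag(g) V^T. The helper names `eigvals_vjp` and `sorted_real_evals` are hypothetical, and the eigenvector contribution is deliberately left out. The formula is checked against central finite differences:

```python
import numpy as np


def eigvals_vjp(A, grad_evals):
    """Map a cotangent on the sorted real eigenvalues of A back to A.

    Hypothetical helper: uses dL/dA = V^{-T} diag(grad_evals) V^T for
    A = V diag(evals) V^{-1}. Assumes A is diagonalizable with real,
    distinct eigenvalues; the eigenvector cotangent term is omitted.
    """
    evals, V = np.linalg.eig(A)
    order = np.argsort(evals.real)   # ascending, matching sorted_real_evals
    V = V[:, order].real             # eigenvectors are real for real eigenvalues
    return np.linalg.inv(V).T @ np.diag(grad_evals) @ V.T


def sorted_real_evals(A):
    """Eigenvalues of A in ascending order (imaginary parts assumed zero)."""
    return np.sort(np.linalg.eigvals(A).real)


# Non-symmetric test matrix with known real eigenvalues 1, 2, 3, 4.
rng = np.random.default_rng(0)
P = np.eye(4) + 0.3 * rng.standard_normal((4, 4))
A = P @ np.diag([1.0, 2.0, 3.0, 4.0]) @ np.linalg.inv(P)
g = np.array([0.5, -1.0, 2.0, 0.25])  # cotangent for L = g @ sorted eigenvalues

analytic = eigvals_vjp(A, g)

# Central finite differences of L with respect to every entry of A.
eps = 1e-6
numeric = np.zeros_like(A)
for j in range(4):
    for k in range(4):
        E = np.zeros_like(A)
        E[j, k] = eps
        numeric[j, k] = (g @ sorted_real_evals(A + E)
                         - g @ sorted_real_evals(A - E)) / (2 * eps)

# Largest discrepancy between the analytic and numeric gradients.
print(np.max(np.abs(analytic - numeric)))
```

Rescaling the columns of V cancels between V^{-T} and V^T, so this eigenvalue gradient does not depend on how the eigenvectors are normalized. The eigenvector term does depend on that normalization, which is one reason it is omitted from this sketch.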

import torch


class eig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        # normalize the shape to be batched
        Ashape = A.shape
        if A.ndim == 2:
            A = A.unsqueeze(0)
        elif A.ndim > 3: